In [1]:
import torch
from torch.utils.data import TensorDataset, DataLoader
from torch import Tensor
from torch import optim
from torch.optim.adam import Adam

from torchvision.datasets import CIFAR10 # type: ignore
from torchvision.transforms import Compose, ToTensor, Normalize, Lambda # type: ignore

import utils

import matplotlib.pyplot as plt

In [2]:
class RGB2YUV:
    """Callable transform that converts a single RGB image tensor to offset YUV.

    The coefficients are the common BT.601 / JPEG YCbCr ones: the first output
    channel is luma, the remaining two are chroma shifted by +0.5.
    """

    @staticmethod
    def __call__(rgb: Tensor) -> Tensor:
        """Convert one channel-first RGB image to YUV.

        Args:
            rgb: tensor of shape (3, H, W) holding the R, G and B planes.

        Returns:
            A new tensor of shape (3, H, W) with channels (Y, U + 0.5, V + 0.5).
            The input is not modified. Because of the final permute the result
            is a non-contiguous view of the freshly computed product.
        """
        # Column j holds the RGB weights of output channel j. The matrix is
        # created on the input's device and, for floating inputs, in the
        # input's dtype so that float64 / GPU tensors do not hit a
        # dtype or device mismatch in the matmul below.
        weights = torch.tensor(
            [[0.29900, -0.16874,  0.50000],
             [0.58700, -0.33126, -0.41869],
             [0.11400,  0.50000, -0.08131]],
            dtype=rgb.dtype if rgb.is_floating_point() else None,
            device=rgb.device,
        )

        # (3, H, W) -> (H, W, 3) so the colour axis is last for the matmul,
        # then back to channel-first.
        yuv = (rgb.permute(1, 2, 0) @ weights).permute(2, 0, 1)
        # Shift both chroma channels by +0.5.
        yuv[1:, :, :] += 0.5
        return yuv


class Normalization:
    """Normalizer for batches of channel-first images whose first channel is luma.

    * Channel 0 (Y) is normalized per image: a Gaussian-blurred copy is
      subtracted and the result is divided by its own spatial standard
      deviation (plus a small epsilon).
    * Channels 1 onward (U, V) are standardized with per-channel statistics
      computed once in :meth:`fit`.
    """

    m: Tensor  # per-channel mean of channels 1.., set by fit(); shape (1, C-1, 1, 1)
    s: Tensor  # per-channel std of channels 1.., set by fit(); shape (1, C-1, 1, 1)

    def __init__(self, kernel_size: int = 7, sigma: float = 1.0) -> None:
        """Store the blur parameters.

        Args:
            kernel_size: side length of the square Gaussian kernel; must be a
                positive odd number so that ``padding=kernel_size // 2`` keeps
                the spatial size unchanged.
            sigma: width of the Gaussian; must be non-zero.

        Raises:
            ValueError: if ``kernel_size`` is not a positive odd number or
                ``sigma`` is zero.
        """
        # An even kernel would make the blurred map one pixel larger than the
        # input and break the subtraction in transform(); fail early instead.
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(f"kernel_size must be a positive odd integer, got {kernel_size}")
        # sigma == 0 turns the kernel into NaNs (0 / 0 at the centre tap).
        if sigma == 0:
            raise ValueError("sigma must be non-zero")
        self.kernel_size = kernel_size
        self.sigma = sigma

    @property
    def gaussian_kernel(self) -> Tensor:
        """Unit-sum 2-D Gaussian kernel of shape (1, 1, kernel_size, kernel_size)."""
        # Tap positions centred on zero, e.g. [-3, ..., 3] for kernel_size=7.
        x = torch.arange(self.kernel_size).float() - self.kernel_size // 2
        gaussian_1d = torch.exp(-0.5 * (x / self.sigma).pow(2))
        # Outer product gives the separable 2-D kernel; normalize to sum 1.
        gaussian_2d = gaussian_1d[:, None] * gaussian_1d[None, :]
        gaussian_2d /= gaussian_2d.sum()
        # Add (out_channels, in_channels) axes as expected by conv2d.
        return gaussian_2d.expand(1, 1, -1, -1)

    def fit(self, imgs: Tensor) -> None:
        """Compute mean and std of channels 1 onward over batch and spatial axes.

        Args:
            imgs: tensor of shape (N, C, H, W).
        """
        uv_channels = imgs[:, 1:]
        self.m = uv_channels.mean((0, 2, 3), keepdim=True)
        self.s = uv_channels.std((0, 2, 3), keepdim=True)

    def transform(self, imgs: Tensor) -> Tensor:
        """Normalize a batch using the statistics stored by :meth:`fit`.

        Args:
            imgs: tensor of shape (N, C, H, W); channel 0 is treated as luma.

        Returns:
            A new tensor of the same shape; ``imgs`` is left untouched.
        """
        y_channel = imgs[:, :1]
        uv_channels = imgs[:, 1:]

        # Match the kernel to the input so float64 / GPU batches work too.
        kernel = self.gaussian_kernel.to(dtype=imgs.dtype, device=imgs.device)
        y_blurred = torch.nn.functional.conv2d(y_channel, kernel, padding=self.kernel_size // 2)
        # The subtraction allocates a new tensor, so the in-place division
        # below does not write into the caller's data.
        y_normalized = y_channel - y_blurred
        y_normalized /= y_normalized.std((2, 3), keepdim=True) + 1e-5

        # A constant channel has std == 0; divide by 1 instead so it maps to
        # zeros rather than NaN. Non-degenerate channels are unaffected.
        scale = torch.where(self.s == 0, torch.ones_like(self.s), self.s)
        uv_normalized = (uv_channels - self.m) / scale

        return torch.cat([y_normalized, uv_normalized], 1)

    def fit_transform(self, imgs: Tensor) -> Tensor:
        """Fit the chroma statistics on ``imgs`` and return the normalized batch."""
        self.fit(imgs)
        return self.transform(imgs)


class StandardScaler:
    m: Tensor
    s: Tensor

    def __init__(self, dims: int | tuple[int, ...]):
        self.dims = dims

    def fit(self, t: torch.Tensor):
        self.m = t.mean(self.dims, keepdims=True)
        self.s = t.std(self.dims, keepdims=True)

    def transform(self, t: Tensor):
        return (t - self.m) / self.s

    def fit_transform(self, t: Tensor):
        self.fit(t)
        return self.transform(t)


def get_images(dataset):
    """Stack the first field of every sample in ``dataset`` into one tensor.

    The dataset is iterated once; each sample is indexed with ``[0]`` and the
    collected tensors are stacked along a new leading dimension.
    """
    images = []
    for sample in dataset:
        images.append(sample[0])
    return torch.stack(images)


def get_targets(dataset):
    """Collect the second field of every sample in ``dataset`` into a 1-D tensor.

    The dataset is iterated once; each sample is indexed with ``[1]``.
    """
    labels = []
    for sample in dataset:
        labels.append(sample[1])
    return torch.tensor(labels)


# Per-sample preprocessing applied by the dataset: ToTensor() first, then the
# RGB -> YUV conversion defined above.
transform = Compose([
    ToTensor(),
    RGB2YUV(),
])

# CIFAR-10 train / test splits stored under ./data (download=True).
raw_train_dataset = CIFAR10(root='./data', train=True, download=True, transform=transform)
raw_test_dataset = CIFAR10(root='./data', train=False, download=True, transform=transform)

# Materialize each split as in-memory tensors. The chroma statistics are fitted
# on the training images only (fit_transform) and reused unchanged for the test
# images (transform), so the order of the next two statements matters.
# NOTE(review): get_images and get_targets each iterate the raw dataset, so
# every split is run through `transform` twice.
norm = Normalization()
train_dataset = TensorDataset(norm.fit_transform(get_images(raw_train_dataset)), get_targets(raw_train_dataset))
test_dataset = TensorDataset(norm.transform(get_images(raw_test_dataset)), get_targets(raw_test_dataset))

# transform = Compose([
#     # transforms.Resize((224, 224)),
#     ToTensor(),
#     Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # Normalize as required
# ])

# train_dataset = CIFAR10(root='./data', train=True, download=True, transform=transform)
# test_dataset = CIFAR10(root='./data', train=False, download=True, transform=transform)

Files already downloaded and verified
Files already downloaded and verified


In [3]:
# Training batches: 128 samples, reshuffled on every pass, with the incomplete
# final batch dropped.
train_loader: DataLoader[tuple[Tensor, Tensor]] = DataLoader(
    train_dataset,
    batch_size=128,
    shuffle=True,
    drop_last=True,
)

# Evaluation batches: 1000 samples, fixed order, nothing dropped.
test_loader: DataLoader[tuple[Tensor, Tensor]] = DataLoader(
    test_dataset,
    batch_size=1000,
    shuffle=False,
)

In [4]:
# Layer specifications for the VGG variants. Every variant has five stages of
# widths 64, 128, 256, 512, 512 and differs only in how many times the width is
# repeated per stage; each stage is closed by an 'M' entry.
# NOTE(review): 'M' is presumably a max-pooling marker interpreted by
# utils.VGG — confirm there.
cfgs = {
    name: [
        entry
        for width, repeats in zip((64, 128, 256, 512, 512), depths)
        for entry in [width] * repeats + ['M']
    ]
    for name, depths in {
        'VGG11': (1, 1, 2, 2, 2),
        'VGG13': (2, 2, 2, 2, 2),
        'VGG16': (2, 2, 3, 3, 3),
        'VGG19': (2, 2, 4, 4, 4),
    }.items()
}

In [5]:
# Build the model from the 'VGG16' layer specification using the project-local
# utils module.
vgg = utils.VGG(cfgs['VGG16'])

# SGD over all model parameters: lr=0.1, momentum 0.9, weight decay 5e-4.
optimizer = optim.SGD(vgg.parameters(), lr=0.1, momentum=0.9, weight_decay=0.0005)
# StepLR multiplies the learning rate by 0.5 after every 25 scheduler steps
# (presumably one step per epoch inside utils.Trainer — TODO confirm).
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=25, gamma=0.5)

In [6]:
# Hand model, loaders, optimizer and scheduler to the project-local Trainer and
# run it with epochs=300.
# NOTE(review): Trainer internals are not visible in this notebook; the recorded
# output shows loss / eval_loss / accuracy being logged through wandb.
trainer = utils.Trainer(vgg, train_loader, test_loader, optimizer, scheduler, epochs=300)
trainer.train()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mantonii-belyshev[0m. Use [1m`wandb login --relogin`[0m to force relogin


0,1
accuracy,▁▅▇▇▇▇▇█████████████████████████████████
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
eval_loss,█▄▂▂▁▁▁▁▁▁▁▁▁▁▁▂▁▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂
loss,█▅▄▄▃▃▂▂▂▂▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
accuracy,89.33
epoch,299.0
eval_loss,0.58378
loss,0.00142


In [7]:
# Evaluate the trained model. Judging by the recorded output, the argument is
# echoed back as the 'epoch' entry of the returned metrics dict — confirm in
# utils.Trainer.
trainer.test(300)

{'eval_loss': 0.5837773978710175, 'accuracy': 89.33, 'epoch': 300}

In [8]:
# torch.save(vgg.state_dict(), 'vgg16-warmup.pth')
# Save the model weights with the measured accuracy embedded in the file name.
# NOTE(review): this calls trainer.test(300) again just to build the name.
torch.save(vgg.state_dict(), f"vgg16-{trainer.test(300)['accuracy']}%.pth")

In [8]:
# NOTE(review): the recorded output of this cell (87.8) differs from the 89.33
# reported by the earlier evaluation, and the execution count [8] is duplicated,
# so this cell was presumably run against a different model state. Re-run the
# notebook from a fresh kernel to confirm the number.
trainer.test(300)['accuracy']

87.8

In [9]:
# # torch.save(vgg.state_dict(), 'vgg16-88.5%.pth')
# vgg = utils.VGG(cfgs['VGG16'])
# vgg.load_state_dict(torch.load('vgg16-88.5%.pth'))

In [10]:
# bayesian_vgg = utils.BayesianVGG(cfgs['VGG16'])
# bayesian_vgg.from_vgg(vgg.cpu())

# optimizer = optim.SGD(bayesian_vgg.parameters(), lr=0.01, momentum=0.9)
# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.98)


# def reg_coef_lambda(epoch: int) -> float:
#     return 1
#     if epoch < 50:
#         return epoch / 50
#     else:
#         return 1.

In [11]:
# trainer = utils.Trainer(bayesian_vgg, train_loader, test_loader, optimizer, scheduler, epochs=300, reg_coef_lambda=reg_coef_lambda)
# trainer.train()

In [12]:
# trainer.test(0)