## Model pre-training with augmentations and random labels


In [1]:
from datetime import datetime
from pathlib import Path

import torch
import torch.backends.cudnn as cudnn
from torchvision import transforms, datasets

from models.model import ConvModel
from models.pre_trainer import PreTrainer

In [2]:
# Select the compute device and make runs reproducible.
# torch.manual_seed seeds the default RNG on all devices (e.g. for model weight initialisation).
torch.manual_seed(0)

if torch.cuda.is_available():
    device = torch.device('cuda')
    gpu_index = 0
    # Deterministic cuDNN kernels alone are not enough for run-to-run reproducibility:
    # benchmark=True auto-tunes the convolution algorithm per run and may pick a different
    # one each time, so benchmarking must be disabled (at some cost in speed).
    cudnn.deterministic = True
    cudnn.benchmark = False
else:
    device = torch.device('cpu')
    # A negative index makes `torch.cuda.device(gpu_index)` a no-op on CPU-only machines.
    gpu_index = -1

In [3]:
# Training-time augmentation pipeline, applied to every CIFAR-10 image on load.
color_jitter = transforms.ColorJitter(brightness=0.8, contrast=0.8, saturation=0.8, hue=0.2)

augmentation_steps = [
    transforms.RandomApply([color_jitter], p=0.8),  # colour jitter on 80% of samples
    transforms.RandomGrayscale(p=0.2),              # grayscale on 20% of samples
    transforms.GaussianBlur(kernel_size=3),          # blur every sample with a 3x3 kernel
    transforms.ToTensor(),                           # PIL image -> tensor, always last
]
composed_transforms = transforms.Compose(augmentation_steps)

In [4]:
# CIFAR-10 training split (fetched into ./datasets if missing) with `composed_transforms` applied.
train_dataset = datasets.CIFAR10('./datasets', download=True, train=True, transform=composed_transforms)

# Seeded RNG handed to the DataLoader. `manual_seed` returns the generator itself, so the
# chained assignment avoids echoing `<torch._C.Generator at 0x...>` as the cell's output.
g = torch.Generator().manual_seed(0)

Files already downloaded and verified


<torch._C.Generator at 0x19c5bc62370>

In [5]:
# Batches of 256 in fixed dataset order; drop_last=True discards the final partial batch.
# NOTE(review): with shuffle=False the seeded generator `g` does not affect sample order;
# confirm that fixed ordering is intended (PreTrainer.train is later called with shuffle=True).
# Pin host memory only when a CUDA device exists: on CPU-only machines pinning has no effect
# (newer PyTorch versions emit a warning about it).
train_loader_1 = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=256,
    shuffle=False,
    pin_memory=torch.cuda.is_available(),
    drop_last=True,
    generator=g,
)

In [6]:
# Backbone to pre-train, configured for CIFAR-10.
# NOTE(review): out_dim=10 presumably means one output per CIFAR-10 class; confirm in models/model.py.
model = ConvModel(
    dataset='CIFAR10',
    out_dim=10,
)



In [7]:
# Adam over all model parameters: learning rate 3e-4 with L2 weight decay 1e-4.
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=1e-4)

# Cosine-annealed learning rate, decaying towards eta_min=0 over T_max scheduler steps.
# NOTE(review): T_max is the number of batches per epoch, while training runs for 200 epochs;
# confirm whether PreTrainer steps the scheduler per batch or per epoch.
steps_per_epoch = len(train_loader_1)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=steps_per_epoch,
    eta_min=0,
    last_epoch=-1,
)

# gpu_index is -1 on CPU-only machines, which makes torch.cuda.device a no-op.
with torch.cuda.device(gpu_index):
    # Timestamped TensorBoard-style run directory, one per launch (minute resolution).
    run_log_dir = f'runs/Pretraining_CIFAR10_{datetime.now().strftime("%d-%m-%Y_%H-%M")}'
    pre_trainer = PreTrainer(
        model=model,
        optimizer=optimizer,
        scheduler=scheduler,
        device=device,
        epochs=200,
        log_every_n_steps=100,
        log_dir=run_log_dir,
    )
    # NOTE(review): what shuffle=True shuffles is defined inside PreTrainer.train; confirm.
    pre_trainer.train(train_loader_1, shuffle=True)

100%|██████████| 195/195 [01:04<00:00,  3.04it/s]
100%|██████████| 195/195 [00:49<00:00,  3.97it/s]
100%|██████████| 195/195 [00:48<00:00,  4.05it/s]
100%|██████████| 195/195 [00:48<00:00,  4.05it/s]
100%|██████████| 195/195 [00:49<00:00,  3.98it/s]
100%|██████████| 195/195 [00:50<00:00,  3.89it/s]
100%|██████████| 195/195 [00:49<00:00,  3.92it/s]
100%|██████████| 195/195 [00:49<00:00,  3.93it/s]
100%|██████████| 195/195 [00:49<00:00,  3.90it/s]
100%|██████████| 195/195 [00:49<00:00,  3.90it/s]
100%|██████████| 195/195 [00:50<00:00,  3.83it/s]
100%|██████████| 195/195 [00:50<00:00,  3.90it/s]
100%|██████████| 195/195 [00:50<00:00,  3.88it/s]
100%|██████████| 195/195 [00:50<00:00,  3.86it/s]
100%|██████████| 195/195 [00:49<00:00,  3.92it/s]
100%|██████████| 195/195 [00:50<00:00,  3.84it/s]
100%|██████████| 195/195 [00:51<00:00,  3.78it/s]
100%|██████████| 195/195 [00:47<00:00,  4.06it/s]
100%|██████████| 195/195 [00:48<00:00,  4.06it/s]
100%|██████████| 195/195 [00:47<00:00,  4.07it/s]


KeyboardInterrupt: 

In [8]:
# Persist the model weights as they are when this cell runs. The recorded training run was
# stopped by KeyboardInterrupt before all 200 epochs completed.
# torch.save does not create missing directories, so make sure the target folder exists first.
PRETRAINED_PATH = Path('saved_models/pretraining/pretrained_cifar10.pth')
PRETRAINED_PATH.parent.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), PRETRAINED_PATH)