## Training of feature extractor (head layer)

In [1]:
import os
from datetime import datetime

import torch
import torch.backends.cudnn as cudnn
from torchvision import transforms, datasets

from models.model import ConvModel
from models.head_trainer import HeadTrainer

In [2]:
# Select the compute device. `gpu_index` is later handed to torch.cuda.device();
# -1 is used as the "no GPU" value on CPU-only machines.
if torch.cuda.is_available():
    device = torch.device('cuda')
    # Reproducible cuDNN behaviour: only deterministic kernels, and no
    # auto-tuner. Benchmark mode may select a different algorithm from run to
    # run, which conflicts with deterministic=True, so it is switched off.
    cudnn.deterministic = True
    cudnn.benchmark = False
    gpu_index = 0
else:
    device = torch.device('cpu')
    gpu_index = -1

In [3]:
# Build the network and restore the pretrained weights.
model = ConvModel(out_dim=10, dataset='CIFAR10')

# map_location='cpu' lets a checkpoint that was saved from a GPU run be loaded
# on a CPU-only machine as well; load_state_dict() then copies the values into
# the model's existing parameters.
model.load_state_dict(
    torch.load('saved_models/pretraining/pretrained_cifar10.pth', map_location='cpu')
)



<All keys matched successfully>

In [4]:
# Dedicated RNG with a fixed seed; it is passed to the DataLoader so the
# shuffle order is repeatable. manual_seed() returns the generator itself.
g = torch.Generator().manual_seed(0)
g

<torch._C.Generator at 0x1ed83300bf0>

In [5]:
# CIFAR-10 training split (downloaded to ./datasets if missing); the only
# transform applied is the conversion to tensors.
head_dataset = datasets.CIFAR10(
    root='./datasets',
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)

# Shuffled batches of 256 driven by the seeded generator `g`. drop_last=True
# discards the final incomplete batch, so every batch has the same size.
head_loader = torch.utils.data.DataLoader(
    dataset=head_dataset,
    batch_size=256,
    shuffle=True,
    generator=g,
    pin_memory=True,
    drop_last=True,
)

Files already downloaded and verified


In [6]:
# New classification head: 512 -> 128 -> 10 with ReLU and BatchNorm in between.
head = torch.nn.Sequential(
    torch.nn.Linear(in_features=512, out_features=128, bias=True),
    torch.nn.ReLU(inplace=True),
    torch.nn.BatchNorm1d(128),
    torch.nn.Linear(in_features=128, out_features=10, bias=True),
)
model.backbone.fc = head

# Freeze every backbone parameter first, then re-enable gradients for the
# freshly attached head only (the head is a submodule of the backbone, so the
# order of these two calls matters).
model.backbone.requires_grad_(False)
head.requires_grad_(True)

Sequential(
  (0): Linear(in_features=512, out_features=128, bias=True)
  (1): ReLU(inplace=True)
  (2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (3): Linear(in_features=128, out_features=10, bias=True)
)

In [7]:
# Adam over all model parameters with a small L2 weight decay.
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.0003, weight_decay=1e-4)

# Cosine annealing of the learning rate down to 0.
# NOTE(review): T_max equals the number of batches in ONE epoch, while the
# trainer below is asked for 400 epochs. Whether this is intended depends on
# how often HeadTrainer calls scheduler.step() -- confirm against the trainer.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=len(head_loader),
    eta_min=0,
    last_epoch=-1,
)

In [8]:
# Run the head training with the chosen GPU as the current CUDA device.
# torch.cuda.device() is a no-op for a negative index, which is what
# `gpu_index` holds on CPU-only machines.
with torch.cuda.device(gpu_index):
    # Timestamp the log directory so every run writes to its own folder.
    run_stamp = datetime.now().strftime("%d-%m-%Y_%H-%M")
    simclr = HeadTrainer(
        model=model,
        optimizer=optimizer,
        scheduler=scheduler,
        device=device,
        epochs=400,
        log_every_n_steps=100,
        log_dir=f'runs/Head_training_CIFAR10_{run_stamp}',
    )
    simclr.train(head_loader)

100%|██████████| 195/195 [00:20<00:00,  9.74it/s]
100%|██████████| 195/195 [00:06<00:00, 28.02it/s]
100%|██████████| 195/195 [00:06<00:00, 28.20it/s]
100%|██████████| 195/195 [00:06<00:00, 28.13it/s]
100%|██████████| 195/195 [00:06<00:00, 28.25it/s]
100%|██████████| 195/195 [00:06<00:00, 28.13it/s]
100%|██████████| 195/195 [00:07<00:00, 27.83it/s]
100%|██████████| 195/195 [00:07<00:00, 27.75it/s]
100%|██████████| 195/195 [00:06<00:00, 27.91it/s]
100%|██████████| 195/195 [00:06<00:00, 27.91it/s]
100%|██████████| 195/195 [00:07<00:00, 27.73it/s]
100%|██████████| 195/195 [00:06<00:00, 27.86it/s]
100%|██████████| 195/195 [00:07<00:00, 25.87it/s]
100%|██████████| 195/195 [00:07<00:00, 25.72it/s]
100%|██████████| 195/195 [00:07<00:00, 27.27it/s]
100%|██████████| 195/195 [00:07<00:00, 27.39it/s]
100%|██████████| 195/195 [00:07<00:00, 27.09it/s]
100%|██████████| 195/195 [00:07<00:00, 27.02it/s]
100%|██████████| 195/195 [00:07<00:00, 25.68it/s]
100%|██████████| 195/195 [00:07<00:00, 26.36it/s]


KeyboardInterrupt: 

In [9]:
# Persist the model weights. The target directory is created first so that a
# missing folder cannot make the save fail after a long training run.
save_path = 'saved_models/head_training/head_cifar10.pth'
os.makedirs(os.path.dirname(save_path), exist_ok=True)
torch.save(model.state_dict(), save_path)