In [1]:
from model.model import ConvModel, SimCLR
import argparse
import torch
import torch.backends.cudnn as cudnn
from torchvision import models
from torchvision import transforms, datasets


In [2]:
if torch.cuda.is_available():
    device = torch.device('cuda')
    cudnn.deterministic = True
    cudnn.benchmark = True
    gpu_index = 0
else:
    device = torch.device('cpu')
    gpu_index = -1


In [5]:
composed_transforms = transforms.Compose([
        transforms.GaussianBlur(kernel_size=3),
      transforms.RandomRotation(degrees=10), # TODO
      transforms.RandomGrayscale(p=0.2),
      transforms.ToTensor()])

In [6]:
train_dataset = datasets.MNIST('./datasets', train=True,
                               transform=composed_transforms,
                               download=True)
val_dataset = datasets.MNIST('./datasets', download=True, transform=transforms.ToTensor())
g = torch.Generator()
g.manual_seed(0)

<torch._C.Generator at 0x1f9266935f0>

In [7]:
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=256, shuffle=False,
     pin_memory=True, drop_last=True, generator=g)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=256, shuffle=True)

In [8]:
model = ConvModel(out_dim=10)

# model.load_state_dict(torch.load('saved_models/model_01_no_aug_random_labels.pth'))




In [9]:
optimizer = torch.optim.Adam(model.parameters(), 0.0003, weight_decay=1e-4)

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=len(train_loader), eta_min=0,
                                                       last_epoch=-1)

with torch.cuda.device(gpu_index):
    simclr = SimCLR(model=model, optimizer=optimizer, scheduler=scheduler, device=device, epochs=100, log_every_n_steps=100)
    simclr.train(train_loader, shuffle=True)

100%|██████████| 234/234 [00:36<00:00,  6.41it/s]
100%|██████████| 234/234 [00:30<00:00,  7.71it/s]
100%|██████████| 234/234 [00:29<00:00,  7.84it/s]
100%|██████████| 234/234 [00:29<00:00,  7.87it/s]
100%|██████████| 234/234 [00:29<00:00,  7.94it/s]
100%|██████████| 234/234 [00:30<00:00,  7.79it/s]
100%|██████████| 234/234 [00:30<00:00,  7.67it/s]
100%|██████████| 234/234 [00:29<00:00,  7.88it/s]
100%|██████████| 234/234 [00:35<00:00,  6.66it/s]
100%|██████████| 234/234 [00:32<00:00,  7.11it/s]
100%|██████████| 234/234 [00:29<00:00,  7.92it/s]
100%|██████████| 234/234 [00:31<00:00,  7.52it/s]
100%|██████████| 234/234 [00:30<00:00,  7.72it/s]
100%|██████████| 234/234 [00:30<00:00,  7.74it/s]
100%|██████████| 234/234 [00:29<00:00,  7.86it/s]
100%|██████████| 234/234 [00:30<00:00,  7.70it/s]
100%|██████████| 234/234 [00:30<00:00,  7.77it/s]
100%|██████████| 234/234 [00:29<00:00,  7.82it/s]
100%|██████████| 234/234 [00:30<00:00,  7.65it/s]
100%|██████████| 234/234 [00:29<00:00,  7.82it/s]


In [10]:
torch.save(model.state_dict(), 'saved_models/model_final.pth')

In [17]:
model

ConvModel(
  (resnet_model): ResNet(
    (conv1): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, tr

In [12]:
model.backbone.fc = torch.nn.Sequential(
    torch.nn.Linear(in_features=512, out_features=128, bias=True),
    torch.nn.ReLU(inplace=True),
    # torch.nn.BatchNorm2d(128,)
    torch.nn.Linear(in_features=128, out_features=10, bias=True),
)
# liniowa 512 -> 128
# relu
# batchnorm
# liniowa 128 -> 10

In [13]:
model.backbone.requires_grad_(False)
model.backbone.fc.requires_grad_(True)
model

ConvModel(
  (resnet_model): ResNet(
    (conv1): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, tr

In [14]:
optimizer = torch.optim.Adam(model.parameters(), 0.0003, weight_decay=1e-4)

with torch.cuda.device(gpu_index):
    simclr = SimCLR(model=model, optimizer=optimizer, scheduler=scheduler, device=device, epochs=100, log_every_n_steps=100)
    simclr.train(val_loader, shuffle=False)


100%|██████████| 235/235 [00:08<00:00, 28.24it/s]
100%|██████████| 235/235 [00:07<00:00, 29.66it/s]
100%|██████████| 235/235 [00:08<00:00, 29.20it/s]
100%|██████████| 235/235 [00:07<00:00, 30.29it/s]
100%|██████████| 235/235 [00:07<00:00, 29.57it/s]
100%|██████████| 235/235 [00:08<00:00, 28.86it/s]
100%|██████████| 235/235 [00:08<00:00, 29.28it/s]
100%|██████████| 235/235 [00:07<00:00, 29.86it/s]
100%|██████████| 235/235 [00:07<00:00, 29.80it/s]
100%|██████████| 235/235 [00:08<00:00, 29.11it/s]
100%|██████████| 235/235 [00:07<00:00, 29.39it/s]
100%|██████████| 235/235 [00:07<00:00, 30.44it/s]
100%|██████████| 235/235 [00:07<00:00, 30.39it/s]
100%|██████████| 235/235 [00:07<00:00, 29.85it/s]
100%|██████████| 235/235 [00:07<00:00, 30.40it/s]
100%|██████████| 235/235 [00:07<00:00, 30.27it/s]
100%|██████████| 235/235 [00:07<00:00, 30.42it/s]
100%|██████████| 235/235 [00:07<00:00, 29.76it/s]
100%|██████████| 235/235 [00:07<00:00, 30.06it/s]
100%|██████████| 235/235 [00:07<00:00, 30.30it/s]
