<a href="https://colab.research.google.com/github/TK-brsq/Research/blob/main/SimCLR_by_Resnet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Import and Data

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler as lr_scheduler
import torch.utils as utils
from torch.utils.data import DataLoader, ConcatDataset

import torchvision
from torchvision import datasets, transforms

from tqdm import tqdm
import numpy as np

In [None]:
# Data augmentation is defined here
class CutOut:
    """Blank out one randomly placed square patch of an image.

    Args:
        size: Side length, in pixels, of the square that is set to zero.
    """

    def __init__(self, size=10):
        self.size = size

    def __call__(self, img):
        """Return a PIL image equal to ``img`` with one square patch zeroed.

        ``img`` is converted with ``np.array``; for a PIL image this gives an
        array whose first two axes are (height, width), with channels last.
        """
        img = np.array(img)
        height, width = img.shape[:2]
        # Clamp the patch so images smaller than `size` are still handled.
        patch_h = min(self.size, height)
        patch_w = min(self.size, width)
        # `high` is exclusive, hence the +1 so the last valid offset can be drawn.
        top = torch.randint(0, height - patch_h + 1, (1,)).item()
        left = torch.randint(0, width - patch_w + 1, (1,)).item()
        # Index rows/columns first: the array is (H, W[, C]), not (C, H, W).
        img[top:top + patch_h, left:left + patch_w] = 0
        return transforms.ToPILImage()(img)

class ImgAugmentation:
    """Callable that turns one image into a pair of randomly augmented views.

    The pipeline is: horizontal flip (p=0.5), CutOut (p=0.5), colour jitter
    (p=0.8), then conversion to a tensor.
    """

    def __init__(self):
        patch_eraser = CutOut()
        jitter = transforms.ColorJitter(0.5, 0.5, 0.5, 0.5)
        steps = [
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomApply([patch_eraser], p=0.5),
            transforms.RandomApply([jitter], p=0.8),
            transforms.ToTensor(),
        ]
        self.transform = transforms.Compose(steps)

    def __call__(self, x):
        """Apply the random pipeline twice to ``x`` and return both results."""
        first_view = self.transform(x)
        second_view = self.transform(x)
        return first_view, second_view

# Labels are not used for contrastive pre-training, so the CIFAR-10 train and
# test splits are merged into a single dataset of augmented image pairs.
train = datasets.CIFAR10(root='./data', train=True, transform=ImgAugmentation(), download=True)
test = datasets.CIFAR10(root='./data', train=False, transform=ImgAugmentation(), download=True)
# Named `all_data` rather than `all` so the builtin `all()` is not shadowed.
all_data = ConcatDataset([train, test])

# 60,000 images in total. drop_last=True keeps every batch at exactly 256
# samples, which is the batch size NT_Xent() is built for by default.
all_loader = DataLoader(all_data, batch_size=256, shuffle=True, drop_last=True)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:18<00:00, 9137534.30it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [None]:
# STL10 image size is 96

# Model

In [None]:
#resnet18 = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=False)
#resnet18.fc = nn.Identity()
# invalid size

In [None]:
class BasicBlock(nn.Module):
    """Residual block: conv3x3 -> BN -> ReLU -> conv3x3, plus a skip connection.

    Args:
        in_c: Number of input channels.
        out_c: Number of output channels.
        down_sampling: When True, the first convolution uses stride 2 and the
            skip path goes through a 2x2 stride-2 convolution that maps
            ``in_c`` to ``out_c`` channels. When False, the input is added
            unchanged, so ``in_c`` must equal ``out_c``.

    Note:
        With ``down_sampling=True`` the main path (3x3, stride 2, padding 1)
        and the skip path (2x2, stride 2, no padding) only produce matching
        spatial sizes for even input heights/widths.
    """

    def __init__(self, in_c, out_c, down_sampling=False):
        super().__init__()
        self.down_sampling = down_sampling
        self.stride = 2 if down_sampling else 1
        if down_sampling:
            # Projection for the skip path; created first so parameter
            # registration order matches earlier versions of this block.
            self.convds = nn.Conv2d(in_c, out_c, 2, 2)
        self.conv1 = nn.Conv2d(in_c, out_c, 3, self.stride, 1)
        self.bn1 = nn.BatchNorm2d(out_c)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_c, out_c, 3, 1, 1)

    def forward(self, x):
        """Return ``relu(conv2(relu(bn1(conv1(x)))) + skip(x))``."""
        skip = self.convds(x) if self.down_sampling else x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        # Out-of-place add: never mutates a tensor other code may still hold.
        out = out + skip
        return self.relu(out)

class ResNet18(nn.Module):
    """ResNet-18-style encoder built from ``BasicBlock``s.

    Maps a batch of 3-channel images to a batch of 512-dimensional feature
    vectors: a stride-2 stem convolution, a 2x2 max-pool, four stages of two
    residual blocks each (64, 128, 256, 512 channels), global average pooling
    and a flatten.

    Note:
        The output has 512 features per image. A head placed on top of this
        encoder must accept 512 inputs (not 4*4*128).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, 3, 2, 1)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(2)
        self.block11 = BasicBlock(64, 64)
        self.block12 = BasicBlock(64, 64)
        self.block21 = BasicBlock(64, 128, True)
        self.block22 = BasicBlock(128, 128)
        self.block31 = BasicBlock(128, 256, True)
        self.block32 = BasicBlock(256, 256)
        self.block41 = BasicBlock(256, 512, True)
        self.block42 = BasicBlock(512, 512)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.flatten = nn.Flatten()

    def forward(self, x):
        """Encode ``x`` of shape (batch, 3, H, W) into (batch, 512).

        The shape comments below are for a 96x96 input; a 32x32 input follows
        the same path with spatial sizes 16, 8, 8, 4, 2, 1.
        """
        # Stage 1: stem
        x = self.conv1(x)  # [3, 96, 96] -> [64, 48, 48]
        # Fix: the activation is a module attribute; bare `relu` is undefined.
        x = self.relu(x)
        # Stage 2
        x = self.maxpool(x)  # [64, 48, 48] -> [64, 24, 24]
        x = self.block11(x)  # [64, 24, 24] -> [64, 24, 24]
        x = self.block12(x)
        # Stage 3
        x = self.block21(x)  # [64, 24, 24] -> [128, 12, 12]
        x = self.block22(x)
        # Stage 4
        x = self.block31(x)  # [128, 12, 12] -> [256, 6, 6]
        x = self.block32(x)
        # Stage 5
        x = self.block41(x)  # [256, 6, 6] -> [512, 3, 3]
        x = self.block42(x)
        # Stage 6: global pooling
        x = self.avgpool(x)  # [512, 3, 3] -> [512, 1, 1]
        x = self.flatten(x)  # [512, 1, 1] -> [512]

        return x

In [None]:
class SimCLR(nn.Module):
    """Encoder plus projection head used for contrastive pre-training.

    Args:
        encoder: Module mapping an image batch to feature vectors of width
            ``feature_dim``.
        feature_dim: Width of the encoder output. Defaults to 4*4*128, the
            value that used to be hard-coded here.
        hidden_dim: Width of the projector's hidden layer.
        out_dim: Width of the projected embedding.

    Note:
        NOTE(review): the projector stacks two ``Linear`` layers with no
        activation in between, so it is equivalent to a single linear map.
        It is left that way so existing checkpoints (keys ``projector.0`` and
        ``projector.1``) keep loading.
    """

    def __init__(self, encoder, feature_dim=4*4*128, hidden_dim=1024, out_dim=64):
        super().__init__()
        self.encoder = encoder
        self.projector = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim),
            nn.Linear(hidden_dim, out_dim)
        )

    def forward(self, xi, xj):
        """Project both views: ([b, c, h, w], [b, c, h, w]) -> ([b, dim], [b, dim]).

        The two views go through the encoder in separate passes.
        """
        hi = self.encoder(xi)
        hj = self.encoder(xj)
        zi = self.projector(hi)
        zj = self.projector(hj)
        return zi, zj

# NT-Xent Loss

In [None]:
class NT_Xent(nn.Module):
    """Normalized temperature-scaled cross-entropy (NT-Xent) loss.

    Args:
        batch_size: Expected number of pairs per batch; the negative-pair
            mask for this size is built once up front. Batches of another
            size are still accepted (their mask is built on the fly).
        temperature: Cosine similarities are divided by this value. The
            default 0.1 reproduces the former hard-coded factor of 10.
    """

    def __init__(self, batch_size=256, temperature=0.1):
        super().__init__()
        self.batch_size = batch_size
        self.temperature = temperature
        # Non-persistent buffer: follows `.to(device)` with the module but
        # stays out of the state_dict, as it did when it was a plain attribute.
        self.register_buffer('mask', self.make_mask(batch_size), persistent=False)
        self.similarity = nn.CosineSimilarity(dim=2)
        self.criterion = nn.CrossEntropyLoss(reduction='sum')

    def make_mask(self, batch_size):
        """Return a [2b, 2b] bool mask that is True exactly at negative pairs.

        Entries are False on the main diagonal (self-similarity) and on the
        two diagonals at offset +/- batch_size (the positive pairs).
        """
        mask = torch.ones((2*batch_size, 2*batch_size))
        mask = mask.fill_diagonal_(0)
        ones = torch.ones((batch_size))
        mask = mask - torch.diag(ones, batch_size) - torch.diag(ones, -batch_size)
        return mask.bool()

    def forward(self, zi, zj):
        """Return the mean NT-Xent loss over the 2b views in ``zi`` and ``zj``.

        ``zi`` and ``zj`` are [b, f] embeddings where row k of each is a view
        of the same image.
        """
        batch_size = zi.shape[0]
        if batch_size == self.batch_size:
            mask = self.mask
        else:
            # e.g. a final partial batch when the loader does not drop it
            mask = self.make_mask(batch_size)
        # Run on whatever device the inputs live on instead of assuming CUDA.
        mask = mask.to(zi.device)

        z = torch.cat((zi, zj), dim=0)  # z = [2b, f]

        sim = self.similarity(z.unsqueeze(1), z.unsqueeze(0)) / self.temperature  # sim = [2b, 2b]
        sim_ij = torch.diag(sim, batch_size)  # sim_ij = [b]
        sim_ji = torch.diag(sim, -batch_size)  # sim_ji = [b]

        positive = torch.cat((sim_ij, sim_ji), dim=0).reshape(2*batch_size, 1)  # positive = [2b, 1]
        negative = sim[mask].reshape(2*batch_size, -1)  # negative = [2b, 2b-2]

        # Column 0 holds the positive, every other column a negative,
        # so the correct "class" is always index 0.
        logits = torch.cat((positive, negative), dim=1)  # logits = [2b, 2b-1]
        target = torch.zeros(2*batch_size, dtype=torch.long, device=logits.device)
        loss = self.criterion(logits, target)
        loss = loss / (2 * batch_size)

        return loss

# Checkpoint

In [None]:
def save_checkpoint(epoch, model, optimizer, scheduler, criterion):
    """Write the training state to ``simclr_epoch{epoch}.pth`` in the working directory.

    The saved dict holds ``epoch``, the state dicts of ``model``,
    ``optimizer`` and ``scheduler``, and ``criterion`` stored as passed in.
    """
    destination = f'simclr_epoch{epoch}.pth'

    payload = {'epoch': epoch}
    stateful_parts = (
        ('model_sd', model),
        ('optimizer_sd', optimizer),
        ('scheduler_sd', scheduler),
    )
    for key, part in stateful_parts:
        payload[key] = part.state_dict()
    payload['criterion'] = criterion

    torch.save(payload, destination)

def load_checkpoint(filename, model, optimizer, scheduler, map_location=None):
    """Restore training state from a file written by ``save_checkpoint``.

    Args:
        filename: Path of the checkpoint file.
        model: Module whose parameters are overwritten from ``model_sd``.
        optimizer: Optimizer restored in place from ``optimizer_sd``.
        scheduler: LR scheduler restored in place from ``scheduler_sd``.
        map_location: Forwarded to ``torch.load``. Pass e.g. ``'cpu'`` to open
            a checkpoint saved on a GPU on a machine without one. The default
            ``None`` keeps the previous behaviour.

    Returns:
        ``(epoch, criterion)`` exactly as stored in the checkpoint.
    """
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    checkpoint = torch.load(filename, map_location=map_location)

    model.load_state_dict(checkpoint['model_sd'])
    optimizer.load_state_dict(checkpoint['optimizer_sd'])
    scheduler.load_state_dict(checkpoint['scheduler_sd'])

    return checkpoint['epoch'], checkpoint['criterion']

# Contrastive Learning

In [None]:
# NOTE(review): `encoder` is not assigned in any visible cell of this notebook,
# so on a fresh kernel this line raises NameError. SimCLR's projector takes
# 4*4*128 input features, so the encoder must emit vectors of that width.
# TODO: define the encoder explicitly before this cell.
simclr = SimCLR(encoder)
optimizer = optim.Adam(simclr.parameters())
# T_max=8 equals the `epochs = 8` used by the training cell: one cosine decay
# towards eta_min over the whole run.
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=8, eta_min=1e-4)
# Default batch_size=256 matches the batch size of `all_loader`.
nt_xent = NT_Xent()

In [None]:
# Optional resume step. 'simclr_epochx.pth' is a placeholder name: replace the
# 'x' with the epoch number of the checkpoint to resume from. A missing file is
# reported instead of aborting Run-All, so a fresh run trains from scratch.
try:
    epoch, loss = load_checkpoint('simclr_epochx.pth', simclr, optimizer, scheduler)
except FileNotFoundError:
    print('Checkpoint simclr_epochx.pth not found; training from scratch.')

  checkpoint = torch.load(filename)


In [None]:
epochs = 8
start_epochs = 0
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

simclr.to(device)
nt_xent.to(device)

simclr.train()
for epoch in range(start_epochs, epochs):
    loss_epoch = 0  # running sum of per-batch losses for this epoch
    for (xi, xj), _ in tqdm(all_loader):
        optimizer.zero_grad()

        xi, xj = xi.to(device), xj.to(device)
        # Outputs are already on `device` because the model and inputs are.
        zi, zj = simclr(xi, xj)

        loss = nt_xent(zi, zj)
        loss.backward()
        optimizer.step()

        loss_epoch += loss.item()
    # The LR schedule advances once per epoch.
    scheduler.step()
    # Report the mean batch loss so the number does not depend on the number
    # of batches and matches the "avgloss" figures noted below.
    print(f'Epoch{epoch} : loss = {loss_epoch / len(all_loader)}')
# Author's notes from a run on a T4 GPU:
# epoch=8 | 2min/epoch | avgloss(0) = 1.1277, avgloss(4) = 0.146, avgloss(6) = 0.1277, avgloss(7) = 0.1233

1it [00:02,  2.11s/it]

idx0 : 5.628911972045898


51it [00:26,  1.70it/s]

idx50 : 1.4876880645751953


101it [00:48,  2.48it/s]

idx100 : 0.667182981967926


151it [01:09,  2.47it/s]

idx150 : 0.5032867193222046


201it [01:32,  2.23it/s]

idx200 : 0.3775142431259155


234it [01:47,  2.17it/s]


Epoch0 : loss = 1.1277423371107151


1it [00:00,  2.43it/s]

idx0 : 0.3416915237903595


51it [00:24,  2.45it/s]

idx50 : 0.3053037226200104


101it [00:45,  2.24it/s]

idx100 : 0.2668510675430298


151it [01:09,  2.37it/s]

idx150 : 0.22591030597686768


201it [01:30,  2.02it/s]

idx200 : 0.27110421657562256


234it [01:45,  2.23it/s]


Epoch1 : loss = 0.2665347072303805


1it [00:00,  2.51it/s]

idx0 : 0.21391044557094574


51it [00:21,  2.44it/s]

idx50 : 0.2049366980791092


101it [00:42,  2.48it/s]

idx100 : 0.18282517790794373


151it [01:04,  2.47it/s]

idx150 : 0.179269939661026


201it [01:25,  2.26it/s]

idx200 : 0.18979308009147644


234it [01:39,  2.34it/s]


Epoch2 : loss = 0.19696918562946156


1it [00:00,  2.51it/s]

idx0 : 0.16481374204158783


51it [00:21,  2.07it/s]

idx50 : 0.16170160472393036


101it [00:42,  2.48it/s]

idx100 : 0.1504020392894745


151it [01:04,  2.47it/s]

idx150 : 0.1607716828584671


201it [01:25,  2.43it/s]

idx200 : 0.163396418094635


234it [01:39,  2.34it/s]


Epoch3 : loss = 0.16467901395681578


1it [00:00,  2.48it/s]

idx0 : 0.15396012365818024


51it [00:21,  2.41it/s]

idx50 : 0.13757817447185516


101it [00:42,  2.12it/s]

idx100 : 0.14650285243988037


151it [01:05,  2.43it/s]

idx150 : 0.15862494707107544


201it [01:26,  2.50it/s]

idx200 : 0.13836175203323364


234it [01:41,  2.31it/s]


Epoch4 : loss = 0.14653580820458567


1it [00:00,  2.43it/s]

idx0 : 0.13746285438537598


51it [00:22,  2.41it/s]

idx50 : 0.13520996272563934


101it [00:44,  2.46it/s]

idx100 : 0.12420349568128586


151it [01:06,  2.16it/s]

idx150 : 0.13044075667858124


201it [01:27,  2.42it/s]

idx200 : 0.12338107824325562


234it [01:41,  2.30it/s]


Epoch5 : loss = 0.1348522774493083


1it [00:00,  1.97it/s]

idx0 : 0.13929934799671173


51it [00:21,  2.48it/s]

idx50 : 0.12594427168369293


101it [00:42,  2.50it/s]

idx100 : 0.12747791409492493


151it [01:04,  2.43it/s]

idx150 : 0.12982724606990814


201it [01:25,  1.98it/s]

idx200 : 0.12980495393276215


234it [01:39,  2.35it/s]


Epoch6 : loss = 0.12770037072846013


1it [00:00,  2.45it/s]

idx0 : 0.1211642473936081


51it [00:21,  2.21it/s]

idx50 : 0.1251055896282196


101it [00:42,  2.51it/s]

idx100 : 0.12269064038991928


151it [01:03,  2.48it/s]

idx150 : 0.11967761069536209


201it [01:25,  2.26it/s]

idx200 : 0.11613039672374725


234it [01:39,  2.36it/s]

Epoch7 : loss = 0.12334921351100644





In [None]:
# Tag the checkpoint with the number of epochs configured in the training cell
# instead of a hard-coded 8, so the file name stays right if `epochs` changes.
# The last argument is the final epoch's accumulated loss, stored under 'criterion'.
save_checkpoint(epochs, simclr, optimizer, scheduler, loss_epoch)

# Classifier

In [None]:
# No augmentation for the linear-probe stage: images are only converted to tensors.
transform = transforms.Compose([transforms.ToTensor()])

train_inf = datasets.CIFAR10('./data', train=True, transform=transform, download=True)
test_inf = datasets.CIFAR10('./data', train=False, transform=transform)

# drop_last=False so every sample is seen; the accuracy figures downstream
# divide by the full dataset size.
train_inf_loader = DataLoader(train_inf, batch_size=256, shuffle=True, drop_last=False)
# Evaluation needs neither shuffling nor dropped samples.
test_inf_loader = DataLoader(test_inf, batch_size=256, shuffle=False, drop_last=False)

Files already downloaded and verified


In [None]:
# Linear probe: one fully connected layer from the 4*4*128 encoder features
# to the 10 class logits.
classifier = nn.Sequential(nn.Linear(in_features=4 * 4 * 128, out_features=10))

In [None]:
optimizer_inf = optim.Adam(classifier.parameters())
# Fix: the scheduler must wrap the classifier's own optimizer. Wrapping the
# SimCLR `optimizer` left the classifier's learning rate constant.
scheduler_inf = lr_scheduler.ExponentialLR(optimizer_inf, gamma=0.8)
criterion_inf = nn.CrossEntropyLoss()

epochs_inf = 4
start_epochs_inf = 0
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
encoder.to(device)
classifier.to(device)

# The encoder is frozen: it stays in eval mode and is only called under
# no_grad, so only the classifier receives gradients.
encoder.eval()
for epoch in range(start_epochs_inf, epochs_inf):
    tr_loss = 0
    tr_corr = 0
    tr_seen = 0  # samples actually processed (robust to drop_last)
    classifier.train()
    for data, target in tqdm(train_inf_loader):
        optimizer_inf.zero_grad()
        data, target = data.to(device), target.to(device)
        with torch.no_grad():
            out = encoder(data)
        out = classifier(out)
        loss = criterion_inf(out, target)
        loss.backward()
        optimizer_inf.step()

        tr_loss += loss.item()
        _, pred = out.max(1)
        tr_corr += (pred == target).sum().item()
        tr_seen += target.size(0)
    scheduler_inf.step()

    ts_loss = 0
    ts_corr = 0
    ts_seen = 0
    classifier.eval()
    for data, target in tqdm(test_inf_loader):
        data, target = data.to(device), target.to(device)
        with torch.no_grad():
            out = encoder(data)
            out = classifier(out)
            loss = criterion_inf(out, target)
        ts_loss += loss.item()
        _, pred = out.max(1)
        ts_corr += (pred == target).sum().item()
        ts_seen += target.size(0)
    # Losses are sums of per-batch means; accuracies are percentages of the
    # samples actually evaluated.
    print(f'Epoch{epoch} : {tr_loss}, {tr_corr*100/tr_seen} | {ts_loss}, {ts_corr*100/ts_seen}')
# Author's notes: 2min/epoch
# 49%

100%|██████████| 195/195 [00:08<00:00, 22.08it/s]
100%|██████████| 39/39 [00:01<00:00, 23.84it/s]


Epoch0 : 310.1539113521576, 44.24 | 61.54429543018341, 45.17


100%|██████████| 195/195 [00:08<00:00, 22.79it/s]
100%|██████████| 39/39 [00:01<00:00, 20.10it/s]


Epoch1 : 297.32441210746765, 46.806 | 59.630680561065674, 46.54


100%|██████████| 195/195 [00:08<00:00, 23.75it/s]
100%|██████████| 39/39 [00:01<00:00, 19.95it/s]


Epoch2 : 289.3976249694824, 48.198 | 58.69546592235565, 47.95


100%|██████████| 195/195 [00:09<00:00, 21.25it/s]
100%|██████████| 39/39 [00:01<00:00, 24.33it/s]

Epoch3 : 284.1865530014038, 49.1 | 58.17367923259735, 48.34





In [None]:
# Persist only the classifier's weights (state_dict) to the working directory.
torch.save(classifier.state_dict(), 'classifier4ep_on_8ep.pth')