<a href="https://colab.research.google.com/github/Kazuto-Takahashi/Research/blob/main/SimCLR_by_Resnet_ipynb_%E3%81%AE%E3%82%B3%E3%83%94%E3%83%BC.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Import and Data

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler as lr_scheduler
import torch.utils as utils
from torch.utils.data import DataLoader, ConcatDataset
from torchvision import datasets, transforms

from tqdm import tqdm
import numpy as np

In [None]:
# Data augmentation is defined here
class CutOut:
    """Black out one randomly placed square patch of a PIL image.

    Args:
        size: Side length, in pixels, of the erased square (default 24).
    """

    def __init__(self, size=24):
        self.size = size

    def _erase(self, pixels):
        """Zero a random ``size`` x ``size`` patch of ``pixels`` in place.

        Args:
            pixels: Writable array whose first two axes are (height, width);
                a trailing channel axis, if present, is erased in full.

        Returns:
            The same array object, with the patch set to 0.
        """
        height, width = pixels.shape[:2]
        # torch.randint's upper bound is exclusive, hence the +1 so the patch
        # may also touch the bottom/right border.  max(..., 0) keeps the bound
        # valid when the image is smaller than the patch.
        top = int(torch.randint(0, max(height - self.size, 0) + 1, (1,)))
        left = int(torch.randint(0, max(width - self.size, 0) + 1, (1,)))
        pixels[top:top + self.size, left:left + self.size] = 0
        return pixels

    def __call__(self, img):
        """Return a new PIL image equal to ``img`` with one patch erased."""
        # np.array(<PIL image>) is laid out (H, W, C), so the patch has to be
        # cut from the first two axes, not the last two.
        pixels = self._erase(np.array(img))
        return transforms.ToPILImage()(pixels)

class ImgAugmentation:
    """Callable that turns one image into two independently augmented tensors."""

    def __init__(self):
        # Pipeline: horizontal flip (p=0.5) -> CutOut (p=0.5)
        #           -> colour jitter (p=0.8) -> tensor conversion.
        steps = [
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomApply([CutOut()], p=0.5),
            transforms.RandomApply(
                [transforms.ColorJitter(0.5, 0.5, 0.5, 0.5)],
                p=0.8,
            ),
            transforms.ToTensor(),
        ]
        self.transform = transforms.Compose(steps)

    def __call__(self, x):
        """Run the pipeline twice on ``x`` and return both results as a pair."""
        first_view = self.transform(x)
        second_view = self.transform(x)
        return first_view, second_view

# Disabled alternative: CIFAR-10 train + test concatenated (60,000 images).
#train = datasets.CIFAR10(root='./data', train=True, transform=ImgAugmentation(), download=True)
#test = datasets.CIFAR10(root='./data', train=False, transform=ImgAugmentation(), download=True)
#all = ConcatDataset([train, test])
#all_loader = DataLoader(all, batch_size=512, shuffle=True, drop_last=True)#60,000data

# STL10 "unlabeled" split; ImgAugmentation turns every sample into a pair of views.
dataset = datasets.STL10(
    root='./data',
    split='unlabeled',
    transform=ImgAugmentation(),
    download=True,
)
# drop_last=True keeps every batch at exactly 512 samples.
data_loader = DataLoader(
    dataset,
    batch_size=512,
    shuffle=True,
    drop_last=True,
)

Downloading http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz to ./data/stl10_binary.tar.gz


100%|██████████| 2640397119/2640397119 [05:21<00:00, 8211155.78it/s]


Extracting ./data/stl10_binary.tar.gz to ./data


# Model

In [None]:
class BasicBlock(nn.Module):
    """Residual block: conv3x3 -> BN -> ReLU -> conv3x3, plus a skip path.

    Args:
        in_c: Number of input channels.
        out_c: Number of output channels.
        down_sampling: When True, the first convolution uses stride 2 and the
            skip path is a 2x2, stride-2 convolution; otherwise the input is
            added back unchanged.
    """

    def __init__(self, in_c, out_c, down_sampling=False):
        super().__init__()
        self.down_sampling = down_sampling
        self.stride = 2 if down_sampling else 1
        # Creation order matters: it fixes the order of parameters(), so the
        # skip convolution (when present) has to stay first.
        if down_sampling:
            self.convds = nn.Conv2d(in_c, out_c, kernel_size=2, stride=2)
        self.conv1 = nn.Conv2d(
            in_c, out_c, kernel_size=3, stride=self.stride, padding=1
        )
        self.bn1 = nn.BatchNorm2d(out_c)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_c, out_c, kernel_size=3, stride=1, padding=1)
        self.maxpool = nn.MaxPool2d(2)  # registered but not used by forward()

    def forward(self, x):
        """Return ``relu(main_path(x) + skip(x))``."""
        out = self.conv2(self.relu(self.bn1(self.conv1(x))))
        shortcut = self.convds(x) if self.down_sampling else x
        out += shortcut
        return self.relu(out)

class ResNet18(nn.Module):
    """Convolutional encoder built from eight BasicBlocks.

    The output is a flat feature vector with 512 entries per sample (the
    channel count of the last block after global average pooling).
    """

    def __init__(self):
        super().__init__()
        # Stem: stride-2 3x3 convolution, ReLU, 2x2 max-pooling.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(2)
        # Four stages of two blocks each; stages 2-4 open with a
        # down-sampling block that doubles the channel count.
        self.block11 = BasicBlock(64, 64)
        self.block12 = BasicBlock(64, 64)
        self.block21 = BasicBlock(64, 128, True)
        self.block22 = BasicBlock(128, 128)
        self.block31 = BasicBlock(128, 256, True)
        self.block32 = BasicBlock(256, 256)
        self.block41 = BasicBlock(256, 512, True)
        self.block42 = BasicBlock(512, 512)
        # Head: global average pooling, then flatten.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.flatten = nn.Flatten()

    def forward(self, x):
        """Encode an image batch; shape notes below are per sample for 96x96 input."""
        # Stem: [3, 96, 96] -> [64, 48, 48] -> [64, 24, 24]
        x = self.maxpool(self.relu(self.conv1(x)))
        # Stages: [64, 24, 24] -> [128, 12, 12] -> [256, 6, 6] -> [512, 3, 3]
        residual_blocks = (
            self.block11, self.block12,
            self.block21, self.block22,
            self.block31, self.block32,
            self.block41, self.block42,
        )
        for block in residual_blocks:
            x = block(x)
        # Head: [512, 3, 3] -> [512, 1, 1] -> [512]
        return self.flatten(self.avgpool(x))

In [None]:
class SimCLR(nn.Module):
    """Shared encoder followed by a non-linear projection head.

    Args:
        encoder: Module mapping an image batch ``[b, c, h, w]`` to features
            ``[b, feature_dim]``.
        feature_dim: Width of the encoder output (default 512).
        projection_dim: Width of the projected embedding (default 64).
    """

    def __init__(self, encoder, feature_dim=512, projection_dim=64):
        super().__init__()
        self.encoder = encoder
        # Two stacked Linear layers with nothing in between are equivalent to
        # a single linear map, so a ReLU is placed between them.  The Linear
        # layers are registered under the names "0" and "1" so their
        # state_dict keys remain "projector.0.*" and "projector.1.*".
        self.projector = nn.Sequential()
        self.projector.add_module('0', nn.Linear(feature_dim, feature_dim))
        self.projector.add_module('relu', nn.ReLU())
        self.projector.add_module('1', nn.Linear(feature_dim, projection_dim))

    def forward(self, xi, xj):
        """Project two views: ([b, c, h, w], [b, c, h, w]) -> ([b, dim], [b, dim]).

        Each view is passed through the encoder separately, then through the
        same projection head.
        """
        hi = self.encoder(xi)
        hj = self.encoder(xj)
        zi = self.projector(hi)
        zj = self.projector(hj)
        return zi, zj

# NT-Xent Loss

In [None]:
class NT_Xent(nn.Module):
    """Normalized temperature-scaled cross-entropy loss over paired embeddings.

    For every embedding the partner view is the positive and all other
    embeddings in the batch (except itself) are negatives.

    Args:
        batch_size: Batch size the negative mask is pre-built for. Batches of
            another size still work; their mask is built on the fly.
        temperature: Softmax temperature; similarities are multiplied by
            ``1 / temperature`` (default 0.1, i.e. a factor of 10).
    """

    def __init__(self, batch_size=512, temperature=0.1):
        super().__init__()
        self.batch_size = batch_size
        self.scale = 1.0 / temperature
        # Non-persistent buffer: follows .to(device) with the module but is
        # kept out of state_dict().
        self.register_buffer('mask', self.make_mask(batch_size), persistent=False)
        self.similarity = nn.CosineSimilarity(dim=2)
        self.criterion = nn.CrossEntropyLoss(reduction='sum')

    def make_mask(self, batch_size):
        """Return a ``[2b, 2b]`` bool mask that is True exactly at the negatives.

        False entries are the main diagonal (self-similarity) and the two
        off-diagonals at offset ``+-batch_size`` (the positive pairs).
        """
        mask = torch.ones((2*batch_size, 2*batch_size))
        mask = mask.fill_diagonal_(0)
        ones = torch.ones((batch_size))
        mask = mask - torch.diag(ones, batch_size) - torch.diag(ones, -batch_size)
        return mask.bool()

    def forward(self, zi, zj):
        """Compute the mean loss for two ``[b, f]`` embedding batches.

        Raises:
            ValueError: If ``zi`` and ``zj`` have different batch sizes.
        """
        batch_size = zi.shape[0]
        if zj.shape[0] != batch_size:
            raise ValueError(
                f'zi and zj must have the same batch size, got {batch_size} and {zj.shape[0]}'
            )
        n_views = 2 * batch_size

        # Reuse the pre-built mask when the batch has the configured size.
        if batch_size == self.batch_size:
            mask = self.mask
        else:
            mask = self.make_mask(batch_size)
        mask = mask.to(zi.device)

        z = torch.cat((zi, zj), dim=0)  # z = [2b, f]

        sim = self.scale * self.similarity(z.unsqueeze(1), z.unsqueeze(0))  # sim = [2b, 2b]
        sim_ij = torch.diag(sim, batch_size)   # sim_ij = [b]
        sim_ji = torch.diag(sim, -batch_size)  # sim_ji = [b]

        positive = torch.cat((sim_ij, sim_ji), dim=0).reshape(n_views, 1)  # positive = [2b, 1]
        negative = sim[mask].reshape(n_views, -1)  # negative = [2b, 2b-2]

        # logits = [2b, 2b-1]: column 0 holds the positive, the rest are negatives,
        # so index 0 is the correct class for every row.
        logits = torch.cat((positive, negative), dim=1)
        target = torch.zeros(n_views, dtype=torch.long, device=logits.device)

        # The criterion sums over rows; divide to get the mean over all 2b views.
        return self.criterion(logits, target) / n_views

# Checkpoint

In [None]:
def save_checkpoint(epoch, model, optimizer, scheduler, criterion):
    """Write a training checkpoint to ``simclr_epoch{epoch}_resnet.pth``.

    The saved dict holds ``epoch``, the state dicts of ``model``,
    ``optimizer`` and ``scheduler``, and ``criterion`` stored as passed in.
    The file is created relative to the current working directory.
    """
    state = dict(
        epoch=epoch,
        model_sd=model.state_dict(),
        optimizer_sd=optimizer.state_dict(),
        scheduler_sd=scheduler.state_dict(),
        criterion=criterion,
    )
    target_path = f'simclr_epoch{epoch}_resnet.pth'
    torch.save(state, target_path)

def load_checkpoint(filename, model, optimizer, scheduler, map_location='cpu'):
    """Restore training state from a checkpoint file.

    Args:
        filename: Path of a file holding a dict with the keys ``epoch``,
            ``model_sd``, ``optimizer_sd``, ``scheduler_sd`` and ``criterion``.
        model: Module whose ``load_state_dict`` receives ``model_sd``.
        optimizer: Optimizer whose ``load_state_dict`` receives ``optimizer_sd``.
        scheduler: Scheduler whose ``load_state_dict`` receives ``scheduler_sd``.
        map_location: Forwarded to ``torch.load``. The default ``'cpu'`` lets a
            checkpoint written on a GPU machine be read where no GPU exists;
            the values are then copied into the given objects.

    Returns:
        Tuple ``(epoch, criterion)`` as stored in the checkpoint.
    """
    # NOTE: torch.load unpickles the file, so only open checkpoints you trust.
    checkpoint = torch.load(filename, map_location=map_location)

    model.load_state_dict(checkpoint['model_sd'])
    optimizer.load_state_dict(checkpoint['optimizer_sd'])
    scheduler.load_state_dict(checkpoint['scheduler_sd'])

    return checkpoint['epoch'], checkpoint['criterion']

# Contrastive Learning

In [None]:
encoder = ResNet18()
simclr = SimCLR(encoder)
optimizer = optim.Adam(simclr.parameters())
# NOTE(review): T_max=8 although the training cell below runs 10 epochs — confirm intended.
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=8, eta_min=1e-5)
# Take the batch size from the loader so the loss mask and the batches cannot drift apart.
nt_xent = NT_Xent(data_loader.batch_size)

In [None]:
# Resume from a previous run when its checkpoint is present. On a fresh
# kernel the file does not exist yet (it is only written after training),
# so fall back to an untrained model instead of aborting the notebook.
try:
    epoch, loss = load_checkpoint('simclr_epoch10_resnet.pth', simclr, optimizer, scheduler)
except FileNotFoundError:
    epoch, loss = 0, None
    print('simclr_epoch10_resnet.pth not found; starting from scratch.')

  checkpoint = torch.load(filename)


In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
epochs = 10
start_epochs = 0

simclr.to(device)
nt_xent.to(device)
simclr.train()

for epoch in range(start_epochs, epochs):
    loss_epoch = 0  # running sum of the per-batch losses for this epoch
    # Each batch is ((view_i, view_j), labels); the labels are ignored.
    for views, _ in tqdm(data_loader):
        xi, xj = (view.to(device) for view in views)

        zi, zj = simclr(xi, xj)
        zi, zj = zi.to(device), zj.to(device)
        loss = nt_xent(zi, zj)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_epoch += loss.item()
    # The learning-rate schedule advances once per epoch.
    scheduler.step()
    print(f'Epoch{epoch} : loss = {loss_epoch}')
#T4 GPU
#VGG-5 | epoch=8 | 2min/epoch | loss(0) = 1.1277, loss(4) = 0.146, loss(6) = 0.1277, loss(7) = 0.1233
#ResNet18 | epoch=10 | 8min/epoch | loss(0) = 821(/195=4.21), loss(2) = 1.584, loss(4) = 0.7967, loss(6) = 0.541, loss(9) = 0.492

100%|██████████| 195/195 [08:03<00:00,  2.48s/it]


Epoch0 : loss = 821.246425151825


100%|██████████| 195/195 [08:02<00:00,  2.48s/it]


Epoch1 : loss = 378.4808655977249


100%|██████████| 195/195 [07:58<00:00,  2.45s/it]


Epoch2 : loss = 308.9420028924942


100%|██████████| 195/195 [07:59<00:00,  2.46s/it]


Epoch3 : loss = 207.56589543819427


100%|██████████| 195/195 [07:58<00:00,  2.45s/it]


Epoch4 : loss = 155.36485081911087


100%|██████████| 195/195 [08:00<00:00,  2.47s/it]


Epoch5 : loss = 121.89358973503113


100%|██████████| 195/195 [07:55<00:00,  2.44s/it]


Epoch6 : loss = 105.62774163484573


100%|██████████| 195/195 [07:59<00:00,  2.46s/it]


Epoch7 : loss = 98.48255547881126


100%|██████████| 195/195 [07:56<00:00,  2.45s/it]


Epoch8 : loss = 96.89931404590607


100%|██████████| 195/195 [07:57<00:00,  2.45s/it]

Epoch9 : loss = 96.07199031114578





In [None]:
# Label the checkpoint with the number of epochs actually trained instead of
# a literal; the last argument stores the final epoch's summed loss.
save_checkpoint(epochs, simclr, optimizer, scheduler, loss_epoch)

# Classifier

In [None]:
# No augmentation for the linear probe: images are only converted to tensors.
transform = transforms.Compose([transforms.ToTensor()])

train_inf = datasets.STL10('./data', split='train', transform=transform, download=True)
test_inf = datasets.STL10('./data', split='test', transform=transform)

train_inf_loader = DataLoader(train_inf, batch_size=128, shuffle=True, drop_last=True)
# Evaluation must see every test image: dropping the last partial batch would
# leave samples unscored, and shuffling serves no purpose at test time.
test_inf_loader = DataLoader(test_inf, batch_size=128, shuffle=False, drop_last=False)

Downloading http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz to ./data/stl10_binary.tar.gz


100%|██████████| 2640397119/2640397119 [06:07<00:00, 7188133.48it/s] 


Extracting ./data/stl10_binary.tar.gz to ./data


In [None]:
# Linear probe: one fully connected layer from 512 input features to 10 outputs.
classifier = nn.Sequential(nn.Linear(in_features=512, out_features=10))

In [None]:
optimizer_inf = optim.Adam(classifier.parameters())
# The schedule has to drive the classifier's own optimizer; attaching it to
# the SimCLR `optimizer` would leave the classifier's learning rate untouched.
scheduler_inf = lr_scheduler.CosineAnnealingLR(optimizer_inf, T_max=4, eta_min=1e-4)
criterion_inf = nn.CrossEntropyLoss()

epochs_inf = 4
start_epochs_inf = 0
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
encoder.to(device)
classifier.to(device)

# The encoder is a frozen feature extractor: eval mode, and no gradients below.
encoder.eval()
for epoch in range(start_epochs_inf, epochs_inf):
    tr_loss = 0
    tr_corr = 0
    tr_seen = 0  # samples actually processed (the loader may drop a partial batch)
    classifier.train()
    for data, target in tqdm(train_inf_loader):
        optimizer_inf.zero_grad()
        data, target = data.to(device), target.to(device)
        with torch.no_grad():
            out = encoder(data)
        out = classifier(out)
        loss = criterion_inf(out, target)
        loss.backward()
        optimizer_inf.step()

        tr_loss += loss.item()
        _, pred = out.max(1)
        tr_corr += (pred == target).sum().item()
        tr_seen += target.size(0)
    scheduler_inf.step()

    ts_loss = 0
    ts_corr = 0
    ts_seen = 0
    classifier.eval()
    with torch.no_grad():
        for data, target in tqdm(test_inf_loader):
            data, target = data.to(device), target.to(device)
            out = classifier(encoder(data))
            loss = criterion_inf(out, target)
            ts_loss += loss.item()
            _, pred = out.max(1)
            ts_corr += (pred == target).sum().item()
            ts_seen += target.size(0)
    # Accuracy is measured over the samples that were really evaluated;
    # max(..., 1) only guards against an empty loader.
    print(f'Epoch{epoch} : {tr_loss}, {tr_corr*100/max(tr_seen, 1)} | {ts_loss}, {ts_corr*100/max(ts_seen, 1)}')
#2min/epoch
#49%
# linear on resnet_10 = 26% -> is a 512-dim encoder not enough? -> it was bad because of CIFAR-10 -> 30% on STL10!?

100%|██████████| 39/39 [00:05<00:00,  7.59it/s]
100%|██████████| 62/62 [00:05<00:00, 11.52it/s]


Epoch0 : 117.35667157173157, 13.24 | 139.58846187591553, 20.4375


100%|██████████| 39/39 [00:03<00:00, 11.48it/s]
100%|██████████| 62/62 [00:05<00:00, 12.04it/s]


Epoch1 : 81.9593403339386, 25.3 | 123.93009781837463, 27.9375


100%|██████████| 39/39 [00:03<00:00, 11.56it/s]
100%|██████████| 62/62 [00:05<00:00, 11.56it/s]


Epoch2 : 76.94475972652435, 28.88 | 121.93325006961823, 28.775


100%|██████████| 39/39 [00:03<00:00, 12.51it/s]
100%|██████████| 62/62 [00:05<00:00, 10.97it/s]

Epoch3 : 74.58009779453278, 30.52 | 117.1057003736496, 31.0875





In [None]:
# Persist only the linear classifier's weights (its state_dict); the encoder is not included.
# NOTE(review): the filename says "8ep" while the SimCLR cell sets epochs = 10 — confirm the label.
torch.save(classifier.state_dict(), 'classifier4ep_on_8ep.pth')