<a href="https://colab.research.google.com/github/TK-brsq/Research/blob/main/SimCLR2_by_SEW.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Notebook dependencies, installed via IPython `!` shell escapes.
# Disabled TensorFlow (re)install commands:
#!pip uninstall -y tensorflow
#!pip install tensorflow-cpu
#!pip install tensorflow
!pip install spikingjelly
!pip install wandb
# torch_xla is only needed by the *_onTPU training helpers:
#!pip install torch_xla

Collecting spikingjelly
  Downloading spikingjelly-0.0.0.0.14-py3-none-any.whl.metadata (15 kB)
Downloading spikingjelly-0.0.0.0.14-py3-none-any.whl (437 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m437.6/437.6 kB[0m [31m8.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: spikingjelly
Successfully installed spikingjelly-0.0.0.0.14


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler as lrs
from torch.utils.data import DataLoader, ConcatDataset
from torchvision.transforms import v2 as TF
from torchvision import datasets

import spikingjelly
from spikingjelly.activation_based import layer as jnn, neuron, functional as jF
'''
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
from torch_xla.amp import autocast
import torch_xla.debug.metrics as met
'''
from tqdm import tqdm
import wandb

# Model

In [None]:
class BasicBlock(nn.Module):
    """SEW (spike-element-wise) residual block built from multi-step spiking layers.

    Main branch: two grouped (groups=2) 3x3 conv -> BN -> IF-neuron stages; the first
    conv uses stride 2 when ``down_sampling`` is set. The shortcut is the raw input,
    or a 2x2/stride-2 conv -> BN -> IF-neuron path when ``down_sampling`` is set.
    The two branches are combined by element-wise addition (the "ADD" SEW function).

    Args:
        inplane: input channel count.
        outplane: output channel count. Without down-sampling the identity shortcut
            is added directly, so it must equal ``inplane`` in that case.
        down_sampling: halve spatial resolution and route the shortcut through
            ``down_sample``.
    """

    def __init__(self, inplane, outplane, down_sampling=True):
        super().__init__()
        self.down_sampling = down_sampling
        self.stride = 2 if down_sampling else 1
        # Registered before ``layer`` (and even when unused) so state_dict keys and the
        # parameter order seen by optimizers match existing checkpoints.
        self.down_sample = nn.Sequential(
            jnn.Conv2d(inplane, outplane, kernel_size=2, stride=2, bias=False),
            jnn.BatchNorm2d(outplane),
            neuron.IFNode(),
        )
        self.layer = nn.Sequential(
            jnn.Conv2d(inplane, outplane, kernel_size=3, stride=self.stride,
                       padding=1, dilation=1, groups=2, bias=False),
            jnn.BatchNorm2d(outplane),
            neuron.IFNode(),
            jnn.Conv2d(outplane, outplane, kernel_size=3, stride=1,
                       padding=1, dilation=1, groups=2, bias=False),
            jnn.BatchNorm2d(outplane),
            neuron.IFNode(),
        )

    def forward(self, x):
        """Return main-branch spikes plus the (optionally down-sampled) shortcut."""
        out = self.layer(x)
        shortcut = self.down_sample(x) if self.down_sampling else x
        out += shortcut
        return out

class SEW_ResNet(nn.Module):
    """Small SEW ResNet encoder operating in spikingjelly multi-step mode.

    A static image batch is repeated ``T`` times along a new leading time dimension,
    passed through a stem conv, six ``BasicBlock``s (two of which down-sample) and a
    global average pool + flatten.

    Args:
        T: number of simulation time steps.
    """

    # (in_channels, out_channels, down_sampling) for block1 ... block6, in order.
    _BLOCK_SPECS = (
        (32, 32, False),
        (32, 32, False),
        (32, 64, True),
        (64, 64, False),
        (64, 128, True),
        (128, 128, False),
    )

    def __init__(self, T=4):
        super().__init__()
        self.T = T

        self.first = nn.Sequential(
            jnn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, bias=False),
            jnn.BatchNorm2d(32),
            neuron.IFNode(),
        )
        # setattr keeps the attribute names (block1..block6) and registration order,
        # so state_dict keys are unchanged.
        for idx, (cin, cout, down) in enumerate(self._BLOCK_SPECS, start=1):
            setattr(self, f'block{idx}', BasicBlock(cin, cout, down))
        self.last = nn.Sequential(
            jnn.AdaptiveAvgPool2d((1, 1)),
            jnn.Flatten(),
        )

        jF.set_step_mode(self, 'm')

    def forward(self, x):
        """Encode ``x`` (assumed [N, 3, H, W]); output keeps the leading T dimension.

        Neuron states are reset first so each call starts from a fresh membrane state.
        """
        jF.reset_net(self)
        out = x.unsqueeze(0).repeat(self.T, 1, 1, 1, 1)  # [T, N, C, H, W]
        blocks = [getattr(self, f'block{i}') for i in range(1, len(self._BLOCK_SPECS) + 1)]
        for stage in [self.first, *blocks, self.last]:
            out = stage(out)
        return out

In [None]:
class Projector(nn.Module):
    """Spiking MLP projection head g(.) for SimCLR (multi-step mode).

    Linear(indim -> indim) -> IF neuron -> Linear(indim -> outdim); no bias terms.
    The final Linear has no neuron after it, so its output is real-valued.

    Args:
        indim: feature dimension coming from the encoder.
        outdim: embedding dimension fed to the contrastive loss.
    """

    def __init__(self, indim=128, outdim=64):
        super().__init__()
        hidden = indim
        self.projector = nn.Sequential(
            jnn.Linear(indim, hidden, bias=False),
            neuron.IFNode(),
            jnn.Linear(hidden, outdim, bias=False),
        )
        # Disabled initialisation experiment, kept for reference:
        #   nn.init.normal_(self.projector[0].weight, 1/128, 1/128**0.5)
        jF.set_step_mode(self, 'm')

    def forward(self, h):
        """Project encoder features ``h``; neuron states are reset before each call."""
        jF.reset_net(self)
        return self.projector(h)

In [None]:
class SimCLR(nn.Module):
    """SimCLR wrapper: a shared encoder f(.) followed by a projection head g(.).

    Args:
        encoder: module mapping an input view to features.
        projector: module mapping features to contrastive embeddings.
    """

    def __init__(self, encoder, projector):
        super().__init__()
        self.encoder = encoder
        self.projector = projector

    def forward(self, x1, x2):
        """Encode and project both views, then average each result over dim 0.

        Assumes the projector output has time steps on dim 0 (as produced by
        SEW_ResNet, which repeats its input T times there).

        Returns:
            tuple: ``(z1, z2)`` projections for ``x1`` and ``x2``.
        """
        features = [self.encoder(view) for view in (x1, x2)]
        projections = [self.projector(feat) for feat in features]
        return tuple(p.mean(0) for p in projections)

In [None]:
class Classifier(nn.Module):
    """Spiking linear read-out (bias-free Linear -> IF neuron) used as a probe.

    Args:
        indim: input feature dimension.
        classes: number of output classes.
    """

    def __init__(self, indim, classes):
        super().__init__()
        fc = jnn.Linear(indim, classes, bias=False)
        self.layer = nn.Sequential(fc, neuron.IFNode())
        jF.set_step_mode(self, 'm')

    def forward(self, x):
        """Return the IF-neuron output averaged over the leading (time) dimension.

        NOTE(review): the IF neuron emits spikes, so these averages are firing rates
        with a resolution of 1/T; they are fed to CrossEntropyLoss as logits by the
        training loop — confirm this bounded-logit design is intended.
        """
        jF.reset_net(self)
        spikes = self.layer(x)
        return spikes.mean(dim=0)

In [None]:
class NT_Xent(nn.Module):
    """Normalized temperature-scaled cross-entropy (NT-Xent) loss used by SimCLR.

    Rows ``i`` of ``z1`` and ``z2`` are two views of the same sample. Each of the
    ``2N`` rows acts as an anchor whose positive is its counterpart in the other view
    and whose negatives are the remaining ``2N - 2`` rows.

    Args:
        batch_size: expected number of samples N per view; the mask for this size is
            cached. Other batch sizes are handled by building a mask on the fly.
        tau: temperature dividing the cosine similarities.
        device: device the concatenated embeddings are moved to.
    """

    def __init__(self, batch_size, tau, device):
        super(NT_Xent, self).__init__()
        self.batch_size = batch_size
        self.tau = tau
        self.device = device
        self.mask = self.make_mask()
        # Kept for backward compatibility; forward() computes the same
        # x.y / (max(|x|, eps) * max(|y|, eps)) via normalize + matmul instead of a
        # (2N, 2N, D) broadcast.
        self.cosine = nn.CosineSimilarity(dim=2)
        self.Xent = nn.CrossEntropyLoss()

    def make_mask(self, batch_size=None, device=None):
        """Return a ``(2N, 2N)`` bool mask that is True exactly at negative pairs.

        The diagonal (self-similarity) and the positive pairs ``(i, i+N)`` and
        ``(i+N, i)`` are False. ``batch_size`` defaults to ``self.batch_size``.
        """
        n = self.batch_size if batch_size is None else batch_size
        idx = torch.arange(n, device=device)
        not_negative = torch.eye(2 * n, dtype=torch.bool, device=device)
        not_negative[idx, idx + n] = True
        not_negative[idx + n, idx] = True
        return ~not_negative

    def forward(self, z1, z2):
        """Compute the loss for two ``[N, D]`` embedding batches.

        Raises:
            ValueError: if ``z1`` and ``z2`` differ in shape or are empty.
        """
        if z1.shape != z2.shape:
            raise ValueError(f'z1 and z2 must have the same shape, got {tuple(z1.shape)} and {tuple(z2.shape)}')
        n = z1.shape[0]
        if n == 0:
            raise ValueError('NT_Xent needs at least one sample per view')

        z = torch.cat((z1, z2), dim=0).to(self.device)
        z = F.normalize(z, dim=1, eps=1e-8)
        similarity = z @ z.T / self.tau  # [2N, 2N] cosine similarities / tau

        if n == self.batch_size:
            if self.mask.device != similarity.device:
                # Move the cached mask once instead of transferring it every step.
                self.mask = self.mask.to(similarity.device)
            negatives_mask = self.mask
        else:
            negatives_mask = self.make_mask(n, similarity.device)

        # offset=+n picks (i, i+N); offset=-n picks (i+N, i).
        positive = torch.cat(
            (torch.diagonal(similarity, offset=n), torch.diagonal(similarity, offset=-n)),
            dim=0,
        ).unsqueeze(1)
        # Explicit column count so N == 1 (zero negatives) still reshapes correctly.
        negative = similarity[negatives_mask].reshape(2 * n, 2 * n - 2)

        # The positive sits in column 0, so every target label is 0.
        logits = torch.cat((positive, negative), dim=1)
        labels = torch.zeros(2 * n, dtype=torch.long, device=logits.device)
        loss = self.Xent(logits, labels)

        # NOTE(review): CrossEntropyLoss already averages over the 2N anchors, which is
        # the NT-Xent definition in the SimCLR paper; this extra /2 halves the loss and
        # its gradients (effectively halving the learning rate). Kept so logged values
        # stay comparable with earlier runs — confirm whether it is intended.
        return loss / 2

# Utils

In [None]:
class DataAugmentation:
    """SimCLR-style augmentation that returns two independently augmented views.

    Pipeline: random resized crop to 32 (scale 0.36-1), horizontal flip, color
    jitter (p=0.8), random grayscale (p=0.2), then conversion to a float32 tensor
    scaled to [0, 1].
    """

    def __init__(self):
        jitter = TF.ColorJitter(brightness=0.8, contrast=0.8, saturation=0.8, hue=0.2)
        steps = [
            TF.RandomResizedCrop(size=32, scale=(0.36, 1)),
            TF.RandomHorizontalFlip(p=0.5),
            TF.RandomApply([jitter], p=0.8),
            TF.RandomGrayscale(p=0.2),
            TF.ToImage(),
            TF.ToDtype(torch.float32, scale=True),
        ]
        self.tf = TF.Compose(steps)

    def __call__(self, x):
        """Apply the random pipeline twice to ``x`` and return both views."""
        first = self.tf(x)
        second = self.tf(x)
        return first, second

In [None]:
def get_loader(data='cifar10', split='train', batch_size=128, DA=False):
    tf = DataAugmentation() if DA else TF.Compose([TF.ToImage(), TF.ToDtype(torch.float32, scale=True)])
    if data == 'cifar10':
        match split:
            case 'train':
                data = datasets.CIFAR10('./data', train=True, transform=tf, download=True)
            case 'test':
                data = datasets.CIFAR10('./data', train=False, transform=tf, download=True)
            case 'all':
                train = datasets.CIFAR10('./data', train=True, transform=tf, download=True)
                test = datasets.CIFAR10('./data', train=False, transform=tf, download=True)
                data = ConcatDataset([train, test])
    elif data == 'stl10':
        match split:
            case 'train':
                data = datasets.STL10('./data', split='train', transform=tf, download=True)
            case 'test':
                data = datasets.STL10('./data', split='test', transform=tf, download=True)
            case 'all':
                data = datasets.STL10('./data', split='unlabeled', transform=tf, download=True)
    else:
        print(f'{data} is not supported >_<. cifar10 or stl10 is supported')
    loader = DataLoader(data, batch_size, shuffle=True, drop_last=True, num_workers=2)
    return loader

In [None]:
def train_(loader, model, optimizer, scheduler, criterion, device):
    """Run one supervised training epoch, then step the LR scheduler once.

    Returns:
        tuple: (sum of per-batch losses, number of correct top-1 predictions).
    """
    model.train()
    loss_sum, n_correct = 0, 0
    for inputs, labels in tqdm(loader):
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        logits = model(inputs)
        batch_loss = criterion(logits, labels)
        batch_loss.backward()
        optimizer.step()

        loss_sum += batch_loss.item()
        n_correct += (logits.argmax(1) == labels).sum().item()
    scheduler.step()
    return loss_sum, n_correct

In [None]:
def train_onTPU(loader, model, optimizer, scheduler, criterion, device):
    """TPU (torch_xla) variant of ``train_``: one supervised epoch under XLA autocast.

    The loader is wrapped in a torch_xla ParallelLoader for ``device``; the LR
    scheduler is stepped once at the end.

    Returns:
        tuple: (sum of per-batch losses, number of correct top-1 predictions).

    Raises:
        ImportError: if torch_xla is not installed.
    """
    # The module-level torch_xla imports are disabled (kept inside a string literal),
    # so these names would otherwise be undefined here; import them locally.
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.parallel_loader as pl
    from torch_xla.amp import autocast

    running_loss = 0
    correct = 0
    model.train()
    loader = pl.ParallelLoader(loader, [device]).per_device_loader(device)
    for data, target in tqdm(loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        with autocast(xm.xla_device()):
            out = model(data)
            loss = criterion(out, target)
        loss.backward()
        xm.optimizer_step(optimizer)
        xm.mark_step()

        # NOTE(review): .item() pulls values to the host every step, which likely
        # forces an XLA sync; consider accumulating on-device if this is slow.
        running_loss += loss.item()
        correct += (out.argmax(1) == target).sum().item()
    scheduler.step()
    return running_loss, correct

In [None]:
def train_cl(loader, model, optimizer, scheduler, criterion, device):
    """Run one contrastive pre-training epoch, then step the LR scheduler once.

    Each loader item must be ``((view1, view2), label)``; labels are ignored.
    ``model(view1, view2)`` must return the two projections consumed by ``criterion``.

    Returns:
        Sum of the per-batch losses.
    """
    model.train()
    total = 0
    for views, _ in tqdm(loader):
        view1, view2 = (v.to(device) for v in views)
        optimizer.zero_grad()
        proj1, proj2 = model(view1, view2)
        batch_loss = criterion(proj1, proj2)
        batch_loss.backward()
        optimizer.step()

        total += batch_loss.item()
    scheduler.step()
    return total

In [None]:
def train_cl_onTPU(loader, model, optimizer, scheduler, criterion, device):
    """TPU (torch_xla) variant of ``train_cl``: one contrastive epoch under XLA autocast.

    Each loader item must be ``((view1, view2), label)``; labels are ignored.
    The LR scheduler is stepped once at the end.

    Returns:
        Sum of the per-batch losses.

    Raises:
        ImportError: if torch_xla is not installed.
    """
    # The module-level torch_xla imports are disabled (kept inside a string literal),
    # so these names would otherwise be undefined here; import them locally.
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.parallel_loader as pl
    from torch_xla.amp import autocast

    running_loss = 0
    model.train()
    loader = pl.ParallelLoader(loader, [device]).per_device_loader(device)
    for (x1, x2), _ in tqdm(loader):
        x1, x2 = x1.to(device), x2.to(device)
        optimizer.zero_grad()
        with autocast(xm.xla_device()):
            z1, z2 = model(x1, x2)
            loss = criterion(z1, z2)
        loss.backward()
        xm.optimizer_step(optimizer)
        xm.mark_step()

        # NOTE(review): .item() likely forces an XLA sync every step.
        running_loss += loss.item()
    scheduler.step()
    return running_loss

In [None]:
def train_in_cl(loader, encoder, classifier, optimizer, criterion, device):
    """Train ``classifier`` for one epoch on features from a frozen ``encoder``.

    The encoder runs in eval mode under ``torch.no_grad()``, so only the classifier
    receives gradients. No LR scheduler is stepped.

    Returns:
        int: number of correct top-1 predictions over the epoch, computed from the
        logits each batch produced before its optimizer step.
    """
    encoder.eval()
    classifier.train()
    hits = 0
    for inputs, labels in loader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        with torch.no_grad():
            features = encoder(inputs)
        logits = classifier(features)
        criterion(logits, labels).backward()
        optimizer.step()

        hits += (logits.argmax(1) == labels).sum().item()
    return hits

In [None]:
def train_in_cl_onTPU(loader, encoder, classifier, optimizer, criterion, device):
    """TPU (torch_xla) variant of ``train_in_cl``: linear-probe epoch on a frozen encoder.

    Returns:
        int: number of correct top-1 predictions over the epoch.

    Raises:
        ImportError: if torch_xla is not installed.
    """
    # The module-level torch_xla imports are disabled (kept inside a string literal),
    # so these names would otherwise be undefined here; import them locally.
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.parallel_loader as pl
    from torch_xla.amp import autocast

    correct = 0
    encoder.eval()
    classifier.train()
    loader = pl.ParallelLoader(loader, [device]).per_device_loader(device)
    for data, target in loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        with autocast(xm.xla_device()):
            with torch.no_grad():  # encoder stays frozen
                z = encoder(data)
            out = classifier(z)
            loss = criterion(out, target)
        loss.backward()
        xm.optimizer_step(optimizer)
        xm.mark_step()

        correct += (out.argmax(1) == target).sum().item()
    return correct

In [None]:
def save_checkpoint(filename, model, optimizer, scheduler):
    """Save model, optimizer and scheduler state dicts to ``{filename}.pth``.

    The file holds a dict with keys 'model_sd', 'optimizer_sd' and 'scheduler_sd'.
    """
    path = f'{filename}.pth'
    state = dict(
        model_sd=model.state_dict(),
        optimizer_sd=optimizer.state_dict(),
        scheduler_sd=scheduler.state_dict(),
    )
    torch.save(state, path)

def load_checkpoint(filename, model, optimizer, scheduler, map_location=None, weights_only=True):
    """Restore model, optimizer and scheduler state from ``{filename}.pth``.

    Expects the layout written by ``save_checkpoint``: a dict with keys
    'model_sd', 'optimizer_sd' and 'scheduler_sd'.

    Args:
        filename: path without the '.pth' suffix.
        model, optimizer, scheduler: objects whose ``load_state_dict`` is called.
        map_location: forwarded to ``torch.load``; e.g. ``'cpu'`` to open a
            checkpoint saved on GPU on a CPU-only machine.
        weights_only: forwarded to ``torch.load``. The checkpoint holds only state
            dicts, so restricted (non-arbitrary-pickle) loading is the safe default.

    Raises:
        KeyError: if one of the expected keys is missing.
    """
    checkpoint = torch.load(f'{filename}.pth', map_location=map_location, weights_only=weights_only)
    model.load_state_dict(checkpoint['model_sd'])
    optimizer.load_state_dict(checkpoint['optimizer_sd'])
    scheduler.load_state_dict(checkpoint['scheduler_sd'])

# Main

In [None]:
#instance
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#device = torch_xla.device()
#------------------------#
loader = get_loader('cifar10', split='all', batch_size=256, DA=True)
loader2 = get_loader('cifar10', split='test', batch_size=128, DA=False)
N = len(loader2.dataset)
#------------------------#
encoder = SEW_ResNet(4)
projector = Projector()
model = SimCLR(encoder, projector).to(device)
classifier = Classifier(128, 10).to(device)
#------------------------#
optimizer = optim.SGD(model.parameters(), lr=0.3)
optimizer2 = optim.Adam(classifier.parameters())
scheduler1 = lrs.LinearLR(optimizer, start_factor=0.01, total_iters=8)
scheduler2 = lrs.CosineAnnealingLR(optimizer, T_max=4, eta_min=1e-1)
scheduler = lrs.SequentialLR(optimizer, [scheduler1, scheduler2], milestones=[8])
#------------------------#
criterion = NT_Xent(256, 0.2, device).to(device)
criterion2 = nn.CrossEntropyLoss()

wandb.login()
run = wandb.init(
    project = 'SimCLR SEW 1102',
    config = {
        'Architecture': 'SEW-ResNet14(dim=128)',
        'feature dim': 128,
        'embedding dim': 64,
        'T': 4,
        'optim': 'SGD',
        'lr': 0.3,
        'sche1': 'Linear(0.01, 8)',
        'sche2': 'Cosine(8, 1e-1)',
        'sche': 'Seq([8])',
        'criterion': 'NT-Xent',
        'tau': 0.2,
        'Data': 'Cifar10',
        'batch': 256,
        'else': 'groups=2, down_sample.stride=2, ADD'
    }
)

#train
start_epoch = 0
epochs = 16
for epoch in range(start_epoch, epochs):
    loss = train_cl(loader, model, optimizer, scheduler2, criterion, device)
    correct = train_in_cl(loader2, model.encoder, classifier, optimizer2, criterion2, device)
    wandb.log({'loss': loss, 'acc': correct*100/N})
    print(f'Epoch: {epoch} | loss: {loss} | acc: {correct*100/N}%')

wandb.finish()
save_checkpoint('SimCLR_by_SEW_1107_ADD', model, optimizer, scheduler2)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:06<00:00, 27.8MB/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified
Files already downloaded and verified


[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

 ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mtk-0311-h[0m ([33mtk-0311-h-hosei-university[0m). Use [1m`wandb login --relogin`[0m to force relogin


100%|██████████| 234/234 [03:09<00:00,  1.24it/s]


Epoch: 0 | loss: 721.524603843689 | acc: 13.08%


100%|██████████| 234/234 [03:07<00:00,  1.25it/s]


Epoch: 1 | loss: 716.3026320934296 | acc: 16.23%


100%|██████████| 234/234 [03:07<00:00,  1.25it/s]


Epoch: 2 | loss: 698.2833225727081 | acc: 18.45%


100%|██████████| 234/234 [03:07<00:00,  1.25it/s]


Epoch: 3 | loss: 673.5129079818726 | acc: 20.73%


100%|██████████| 234/234 [03:07<00:00,  1.25it/s]


Epoch: 4 | loss: 627.9863755702972 | acc: 21.83%


100%|██████████| 234/234 [03:07<00:00,  1.25it/s]


Epoch: 5 | loss: 570.5778880119324 | acc: 26.18%


100%|██████████| 234/234 [03:07<00:00,  1.25it/s]


Epoch: 6 | loss: 536.1793036460876 | acc: 27.22%


100%|██████████| 234/234 [03:07<00:00,  1.25it/s]


Epoch: 7 | loss: 517.0463840961456 | acc: 27.33%


100%|██████████| 234/234 [03:07<00:00,  1.25it/s]


Epoch: 8 | loss: 496.62525153160095 | acc: 27.99%


100%|██████████| 234/234 [03:07<00:00,  1.25it/s]


Epoch: 9 | loss: 473.5976436138153 | acc: 30.08%


100%|██████████| 234/234 [03:07<00:00,  1.25it/s]


Epoch: 10 | loss: 450.1544636487961 | acc: 31.67%


100%|██████████| 234/234 [03:07<00:00,  1.25it/s]


Epoch: 11 | loss: 436.1648129224777 | acc: 33.3%


100%|██████████| 234/234 [03:07<00:00,  1.25it/s]


Epoch: 12 | loss: 428.8815670013428 | acc: 33.57%


100%|██████████| 234/234 [03:07<00:00,  1.25it/s]


Epoch: 13 | loss: 424.7932713031769 | acc: 34.41%


100%|██████████| 234/234 [03:07<00:00,  1.25it/s]


Epoch: 14 | loss: 424.40572488307953 | acc: 34.89%


100%|██████████| 234/234 [03:07<00:00,  1.25it/s]


Epoch: 15 | loss: 425.8596637248993 | acc: 33.69%


0,1
acc,▁▂▃▃▄▅▆▆▆▆▇▇████
loss,██▇▇▆▄▄▃▃▂▂▁▁▁▁▁

0,1
acc,33.69
loss,425.85966


In [None]:
# Result
# T4 | 13.2 / 15 GB | batch=512, T=2 | 2m / epoch | 1epで10%, 2epで14%
# L4 | 28.9 / 22 GB | same | 1epで14%
# TPU | diagでエラー | positiveのdiagをno_grad()の中に だめ -> loaderをparallel_loaderへ(T=1) 1.5min/epochできた
# TPU | 上のままTを上げていく T=2でダメ -> mixed precision で T=2 OK, 2min/epoch
# 上の記録 | 13.23 -> 16.35 -> 18.44 -> 19.58 -> 19.96 -> 21.24 -> 21.53 -> 21.8

# down_sampleのstride=2, groups=2でやろう. paramsが1200k->500k
# T4 GPU | batch=256, T=3(9.8GB) OK(13%) | T=4(12.7GB) できる(13%) | T=5は無理そう
# TPU | batch=256, T=1 OK(10.4%) | T=2 OK(12.2%) | T=3 OK(13.2%) | T=4 NO(Exhausted)
# ひとまず TPU N=256, T=3, epochs=32, 20%でsaturation
# projectorをANNに変えたらResource Exhausted
# L4 GPU | N=256 T=4 epochs=16 | 3min/epoch -> epoch5 17%でsaturation | z1, z2は異常なし

# 11/7 z1,z2問題解決
# hyper parameterは上と同じ | L4 GPUで3.1min/epoch | RAMは半分 -> まだいけるぞ |
# 34%まではうまくいった -> lrが小さすぎて坂を上れなかったか -> T-wiseでlossとる | どうであれもっとepoch増やしてから分析

# Inspection

In [None]:
l = get_loader(data='cifar10', split='all', batch_size=128, DA=True)  # paired-view loader for inspection

Files already downloaded and verified
Files already downloaded and verified


In [None]:
#Analysis
device = torch.device('cpu')
e = SEW_ResNet(4)
m = SimCLR(e, 128, 64)
cl = Classifier(128, 10)
cr = NT_Xent(128, 0.2, device)
#load state dict
cp = torch.load('SimCLR_by_SEW_1103_failure.pth', weights_only=True, map_location=device)
m.load_state_dict(cp['model_sd'])

(d1, d2), t = next(iter(l))
z1, z2 = m(d1, d2)
print(z1[0])
print(z2[0])
loss = cr(z1, z2)
print(loss)

block2:tensor([[[1., 1., 1., 1.],
         [2., 2., 2., 2.],
         [2., 2., 2., 2.],
         [2., 2., 2., 2.]],

        [[1., 1., 1., 1.],
         [2., 2., 2., 2.],
         [2., 2., 2., 2.],
         [2., 2., 2., 2.]],

        [[1., 1., 1., 1.],
         [2., 2., 2., 2.],
         [2., 2., 2., 2.],
         [2., 2., 2., 2.]],

        [[1., 1., 1., 1.],
         [2., 2., 2., 2.],
         [2., 2., 2., 2.],
         [2., 2., 2., 2.]]], grad_fn=<SliceBackward0>)
block4:tensor([[[2., 1., 1., 1.],
         [2., 1., 1., 1.],
         [2., 1., 0., 1.],
         [2., 1., 1., 1.]],

        [[2., 1., 1., 1.],
         [2., 1., 1., 1.],
         [2., 1., 0., 0.],
         [2., 2., 1., 1.]],

        [[2., 2., 1., 1.],
         [2., 1., 1., 1.],
         [2., 1., 0., 0.],
         [2., 2., 1., 1.]],

        [[2., 1., 1., 1.],
         [2., 1., 1., 1.],
         [2., 1., 0., 0.],
         [2., 2., 1., 1.]]], grad_fn=<SliceBackward0>)
block6:tensor([[[3., 3., 3., 3.],
         [3., 3., 3.