<a href="https://colab.research.google.com/github/TK-brsq/Research/blob/main/SimCLR2_by_SEW.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#!pip uninstall -y tensorflow
#!pip install tensorflow-cpu
#!pip install tensorflow
!pip install spikingjelly
#!pip install wandb
#!pip install torch_xla

Collecting spikingjelly
  Downloading spikingjelly-0.0.0.0.14-py3-none-any.whl.metadata (15 kB)
Downloading spikingjelly-0.0.0.0.14-py3-none-any.whl (437 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/437.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m [32m430.1/437.6 kB[0m [31m18.4 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m437.6/437.6 kB[0m [31m9.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: spikingjelly
Successfully installed spikingjelly-0.0.0.0.14


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler as lrs
from torch.utils.data import DataLoader, ConcatDataset
from torchvision.transforms import v2 as TF
from torchvision import datasets

import spikingjelly
from spikingjelly.activation_based import layer as jnn, neuron, functional as jF

from tqdm import tqdm
import wandb

# Model

In [None]:
class BasicBlock(nn.Module):
    """Residual block: two grouped 3x3 conv/BN/IFNode stages plus an additive shortcut.

    Args:
        inplane: number of input channels.
        outplane: number of output channels.
        down_sampling: when True the first conv uses stride 2 and the shortcut
            goes through ``down_sample`` (2x2 conv, stride 2); when False the
            block keeps the resolution and the shortcut is the raw input.

    Note:
        ``down_sample`` is constructed for every block, including those with
        ``down_sampling=False`` where ``forward`` never calls it, so its
        parameters are always present in ``state_dict()``.
    """

    def __init__(self, inplane, outplane, down_sampling=True):
        super().__init__()
        self.down_sampling = down_sampling
        self.stride = 2 if down_sampling else 1

        # Shortcut branch (only used by forward() when down_sampling is True).
        self.down_sample = nn.Sequential(
            jnn.Conv2d(inplane, outplane, kernel_size=2, stride=2, bias=False),
            jnn.BatchNorm2d(outplane),
            neuron.IFNode(),
        )

        # Main branch: both convolutions use padding=1, dilation=1, groups=2.
        self.layer = nn.Sequential(
            jnn.Conv2d(inplane, outplane, kernel_size=3, stride=self.stride,
                       padding=1, dilation=1, groups=2, bias=False),
            jnn.BatchNorm2d(outplane),
            neuron.IFNode(),
            jnn.Conv2d(outplane, outplane, kernel_size=3, stride=1,
                       padding=1, dilation=1, groups=2, bias=False),
            jnn.BatchNorm2d(outplane),
            neuron.IFNode(),
        )

    def forward(self, x):
        """Return ``layer(x)`` plus the (optionally down-sampled) input."""
        out = self.layer(x)
        shortcut = self.down_sample(x) if self.down_sampling else x
        out += shortcut
        return out

class SEW_ResNet(nn.Module):
    """Spiking encoder: a conv stem, six ``BasicBlock`` stages and global average pooling.

    Channel widths go 3 -> 32 -> 64 -> 128; ``block3`` and ``block5`` are the
    down-sampling blocks. The whole network is switched to spikingjelly's
    multi-step mode (``'m'``) at construction time.

    Args:
        T: number of times the input is repeated along a new leading axis
            before being fed to the network.
    """

    def __init__(self, T=4):
        super().__init__()
        self.T = T

        self.first = nn.Sequential(
            jnn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, bias=False),
            jnn.BatchNorm2d(32),
            neuron.IFNode(),
        )
        self.block1 = BasicBlock(32, 32, down_sampling=False)
        self.block2 = BasicBlock(32, 32, down_sampling=False)
        self.block3 = BasicBlock(32, 64, down_sampling=True)
        self.block4 = BasicBlock(64, 64, down_sampling=False)
        self.block5 = BasicBlock(64, 128, down_sampling=True)
        self.block6 = BasicBlock(128, 128, down_sampling=False)
        self.last = nn.Sequential(
            jnn.AdaptiveAvgPool2d((1, 1)),
            jnn.Flatten(),
        )

        jF.set_step_mode(self, 'm')

    def forward(self, x):
        """Reset the network state, repeat ``x`` ``T`` times and encode it.

        Assumes ``x`` is a 4-D image batch with 3 channels (the stem is
        ``Conv2d(3, ...)``) -- TODO confirm the exact layout with the caller.
        The leading axis of the result is the repeat axis of length ``T``.
        """
        jF.reset_net(self)
        seq = x.unsqueeze(0).repeat(self.T, 1, 1, 1, 1)
        seq = self.first(seq)
        stages = (
            self.block1, self.block2, self.block3,
            self.block4, self.block5, self.block6,
        )
        for stage in stages:
            seq = stage(seq)
        return self.last(seq)

In [None]:
class Projector(nn.Module):
    """Projection head: a single bias-free linear layer from ``indim`` to ``outdim``."""

    def __init__(self, indim=128, outdim=64):
        super().__init__()
        linear_head = nn.Linear(indim, outdim, bias=False)
        # Kept inside a Sequential so the parameter name stays "projector.0.weight".
        self.projector = nn.Sequential(linear_head)

    def forward(self, h):
        """Project features ``h`` (last dimension ``indim``) to ``outdim`` dimensions."""
        return self.projector(h)

In [None]:
class SimCLR(nn.Module):
    """Wraps an encoder and a projector for two-view contrastive training.

    Args:
        encoder: callable whose output is averaged over its first axis.
        projector: callable applied to the averaged encoder output.
    """

    def __init__(self, encoder, projector):
        super().__init__()
        self.encoder = encoder
        self.projector = projector

    def forward(self, x1, x2):
        """Encode both views, average over axis 0, and project each result.

        Returns:
            (z1, z2): the projections of the first and second view.
        """
        # The encoder is called on x1 first, then on x2.
        h1 = self.encoder(x1).mean(0)
        h2 = self.encoder(x2).mean(0)
        return self.projector(h1), self.projector(h2)

In [None]:
class Classifier(nn.Module):
    """Linear read-out followed by an IFNode, run in spikingjelly multi-step mode.

    Args:
        indim: size of the input feature dimension.
        classes: number of output classes.
    """

    def __init__(self, indim, classes):
        super().__init__()
        stages = [
            jnn.Linear(indim, classes, bias=False),
            neuron.IFNode(),
        ]
        self.layer = nn.Sequential(*stages)
        jF.set_step_mode(self, 'm')

    def forward(self, x):
        """Reset neuron state, apply the read-out and average the result over axis 0."""
        jF.reset_net(self)
        spikes = self.layer(x)
        return spikes.mean(0)

In [None]:
class NT_Xent(nn.Module):
    """Temperature-scaled cross-entropy contrastive loss over two views.

    For each of the ``2N`` embeddings the positive is the other view of the
    same sample and the negatives are the remaining ``2N - 2`` embeddings.
    Similarities are cosine similarities divided by ``tau``. The result is the
    mean cross-entropy over the ``2N`` rows, divided by 2.

    Args:
        batch_size: expected number of samples ``N`` per view. A mask for this
            size is precomputed; batches of another size still work, with the
            mask rebuilt on the fly.
        tau: temperature dividing the cosine similarities.
        device: device the concatenated embeddings are moved to.
    """

    def __init__(self, batch_size, tau, device):
        super(NT_Xent, self).__init__()
        self.batch_size = batch_size
        self.tau = tau
        self.device = device
        # Non-persistent buffer: it follows `.to(device)` / `.cuda()` with the
        # module but adds no key to state_dict().
        self.register_buffer('mask', self.make_mask(), persistent=False)
        self.cosine = nn.CosineSimilarity(dim=2)
        self.Xent = nn.CrossEntropyLoss()

    def make_mask(self, batch_size=None):
        """Return a ``(2N, 2N)`` bool mask that is True exactly on negative pairs.

        The main diagonal (self-similarity) and the two offset diagonals
        (positive pairs) are False.

        Args:
            batch_size: ``N``; defaults to ``self.batch_size``.
        """
        n = self.batch_size if batch_size is None else batch_size
        idx = torch.arange(n)
        mask = ~torch.eye(2 * n, dtype=torch.bool)
        mask[idx, idx + n] = False
        mask[idx + n, idx] = False
        return mask

    def forward(self, z1, z2):
        """Compute the loss for two aligned batches of embeddings.

        Args:
            z1, z2: 2-D tensors with the same number of rows; row ``i`` of
                ``z1`` and row ``i`` of ``z2`` form a positive pair.

        Raises:
            ValueError: if ``z1`` and ``z2`` have a different number of rows.
        """
        if z1.shape[0] != z2.shape[0]:
            raise ValueError(
                f'z1 and z2 must have the same batch size, got {z1.shape[0]} and {z2.shape[0]}'
            )
        n = z1.shape[0]

        z = torch.cat((z1, z2), dim=0).to(self.device)
        similarity = self.cosine(z.unsqueeze(1), z.unsqueeze(0)) / self.tau

        idx = torch.arange(n, device=similarity.device)
        sim_ij = similarity[idx, idx + n]
        sim_ji = similarity[idx + n, idx]
        positive = torch.cat([sim_ij, sim_ji], dim=0).reshape(2 * n, 1)

        # Reuse the precomputed mask when the batch has the expected size.
        mask = self.mask if n == self.batch_size else self.make_mask(n)
        mask = mask.to(similarity.device)
        # Explicit width (instead of -1) keeps the reshape valid when n == 1.
        negative = similarity[mask].reshape(2 * n, 2 * n - 2)

        # The positive sits in column 0 of every row, so every label is 0.
        labels = torch.zeros(2 * n, dtype=torch.long, device=similarity.device)
        logits = torch.cat((positive, negative), dim=1)
        loss = self.Xent(logits, labels)

        return loss / 2

# Utils

In [None]:
class DataAugmentation:
    """Callable that produces two independently augmented views of one image.

    Pipeline, in order: random resized crop to 32 (scale 0.36-1), horizontal
    flip (p=0.5), colour jitter applied with p=0.8, random grayscale (p=0.2),
    conversion to an image tensor and to float32 with value scaling.
    """

    def __init__(self):
        jitter = TF.RandomApply(
            [TF.ColorJitter(brightness=0.8, contrast=0.8, saturation=0.8, hue=0.2)],
            p=0.8,
        )
        steps = [
            TF.RandomResizedCrop(size=32, scale=(0.36, 1)),
            TF.RandomHorizontalFlip(p=0.5),
            jitter,
            TF.RandomGrayscale(p=0.2),
            TF.ToImage(),
            TF.ToDtype(torch.float32, scale=True),
        ]
        self.tf = TF.Compose(steps)

    def __call__(self, x):
        """Return ``(view1, view2)``: the pipeline applied twice to the same input."""
        first_view = self.tf(x)
        second_view = self.tf(x)
        return first_view, second_view

In [None]:
def get_loader(data='cifar10', split='train', batch_size=128, DA=False):
    tf = DataAugmentation() if DA else TF.Compose([TF.ToImage(), TF.ToDtype(torch.float32, scale=True)])
    if data == 'cifar10':
        match split:
            case 'train':
                data = datasets.CIFAR10('./data', train=True, transform=tf, download=True)
            case 'test':
                data = datasets.CIFAR10('./data', train=False, transform=tf, download=True)
            case 'all':
                train = datasets.CIFAR10('./data', train=True, transform=tf, download=True)
                test = datasets.CIFAR10('./data', train=False, transform=tf, download=True)
                data = ConcatDataset([train, test])
    elif data == 'stl10':
        match split:
            case 'train':
                data = datasets.STL10('./data', split='train', transform=tf, download=True)
            case 'test':
                data = datasets.STL10('./data', split='test', transform=tf, download=True)
            case 'all':
                data = datasets.STL10('./data', split='unlabeled', transform=tf, download=True)
    else:
        print(f'{data} is not supported >_<. cifar10 or stl10 is supported')
    loader = DataLoader(data, batch_size, shuffle=True, drop_last=True, num_workers=2)
    return loader

In [None]:
def train_(loader, model, optimizer, scheduler, criterion, device):
    """Train ``model`` for one supervised epoch.

    Args:
        loader: iterable of ``(data, target)`` batches.
        model: module called as ``model(data)``; put into train mode.
        optimizer: stepped once per batch.
        scheduler: stepped once after the last batch.
        criterion: called as ``criterion(out, target)``.
        device: device the batches are moved to.

    Returns:
        (running_loss, correct): the sum of the per-batch loss values and the
        number of samples whose ``argmax(1)`` prediction equals the target.
    """
    model.train()
    total_loss, n_correct = 0, 0
    for inputs, labels in tqdm(loader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        logits = model(inputs)
        batch_loss = criterion(logits, labels)
        batch_loss.backward()
        optimizer.step()

        total_loss += batch_loss.item()
        hits = logits.argmax(1) == labels
        n_correct += hits.sum().item()
    scheduler.step()
    return total_loss, n_correct

In [None]:
def train_onTPU(loader, model, optimizer, scheduler, criterion, device):
    """Run one supervised training epoch through the ``pl`` / ``xm`` helpers.

    Mirrors ``train_``: sums the per-batch loss values, counts samples whose
    ``argmax(1)`` prediction equals the target, and steps ``scheduler`` once
    after the last batch.

    NOTE(review): ``pl``, ``xm`` and ``autocast`` are not imported anywhere in
    the visible notebook -- presumably ``torch_xla.distributed.parallel_loader``,
    ``torch_xla.core.xla_model`` and ``torch_xla.amp.autocast``; confirm and
    import them before calling this function.

    Returns:
        (running_loss, correct)
    """
    running_loss = 0
    correct = 0
    model.train()
    # Re-wrap the loader with the parallel loader for the single given device.
    loader = pl.ParallelLoader(loader, [device]).per_device_loader(device)
    for data, target in tqdm(loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        # NOTE(review): autocast is given xm.xla_device(), not the `device` argument -- confirm they agree.
        with autocast(xm.xla_device()):
            out = model(data)
            loss = criterion(out, target)
        loss.backward()
        # The optimizer is stepped through xm rather than optimizer.step().
        xm.optimizer_step(optimizer)
        xm.mark_step()

        running_loss += loss.item()
        correct += (out.argmax(1) == target).sum().item()
    scheduler.step()
    return running_loss, correct

In [None]:
def train_cl(loader, model, optimizer, scheduler, criterion, device):
    """Train ``model`` for one contrastive epoch on two-view batches.

    Args:
        loader: iterable of ``((x1, x2), label)`` batches; the label is ignored.
        model: module called as ``model(x1, x2)`` returning ``(z1, z2)``.
        optimizer: stepped once per batch.
        scheduler: stepped once after the last batch.
        criterion: called as ``criterion(z1, z2)``.
        device: device both views are moved to.

    Returns:
        The sum of the per-batch loss values over the epoch.
    """
    model.train()
    total_loss = 0
    for views, _ in tqdm(loader):
        view_a, view_b = views
        view_a = view_a.to(device)
        view_b = view_b.to(device)

        optimizer.zero_grad()
        emb_a, emb_b = model(view_a, view_b)
        batch_loss = criterion(emb_a, emb_b)
        batch_loss.backward()
        optimizer.step()

        total_loss += batch_loss.item()
    scheduler.step()
    return total_loss

In [None]:
def train_cl_onTPU(loader, model, optimizer, scheduler, criterion, device):
    """Run one contrastive training epoch through the ``pl`` / ``xm`` helpers.

    Mirrors ``train_cl``: each batch is ``((x1, x2), label)`` with the label
    ignored, the per-batch loss values are summed, and ``scheduler`` is stepped
    once after the last batch.

    NOTE(review): ``pl``, ``xm`` and ``autocast`` are not imported anywhere in
    the visible notebook -- presumably ``torch_xla.distributed.parallel_loader``,
    ``torch_xla.core.xla_model`` and ``torch_xla.amp.autocast``; confirm and
    import them before calling this function.

    Returns:
        running_loss: the summed per-batch loss values.
    """
    running_loss = 0
    model.train()
    # Re-wrap the loader with the parallel loader for the single given device.
    loader = pl.ParallelLoader(loader, [device]).per_device_loader(device)
    for (x1, x2), _ in tqdm(loader):
        x1, x2 = x1.to(device), x2.to(device)
        optimizer.zero_grad()
        # NOTE(review): autocast is given xm.xla_device(), not the `device` argument -- confirm they agree.
        with autocast(xm.xla_device()):
            z1, z2 = model(x1, x2)
            loss = criterion(z1, z2)
        loss.backward()
        # The optimizer is stepped through xm rather than optimizer.step().
        xm.optimizer_step(optimizer)
        xm.mark_step()

        running_loss += loss.item()
    scheduler.step()
    return running_loss

In [None]:
def train_in_cl(loader, encoder, classifier, optimizer, criterion, device):
    """Train ``classifier`` for one epoch on features from a frozen ``encoder``.

    The encoder is put in eval mode and run under ``torch.no_grad()``, so only
    the classifier receives gradients.

    Args:
        loader: iterable of ``(data, target)`` batches.
        encoder: feature extractor called as ``encoder(data)``.
        classifier: module called on the encoder output; put into train mode.
        optimizer: stepped once per batch.
        criterion: called as ``criterion(out, target)``.
        device: device the batches are moved to.

    Returns:
        Number of samples whose ``argmax(1)`` prediction equals the target,
        counted on the same batches the classifier is being trained on.
    """
    encoder.eval()
    classifier.train()
    n_correct = 0
    for inputs, labels in loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        with torch.no_grad():
            features = encoder(inputs)
        logits = classifier(features)
        criterion(logits, labels).backward()
        optimizer.step()

        hits = logits.argmax(1) == labels
        n_correct += hits.sum().item()
    return n_correct

In [None]:
def train_in_cl_onTPU(loader, encoder, classifier, optimizer, criterion, device):
    """Train ``classifier`` for one epoch on frozen-encoder features via ``pl`` / ``xm``.

    Mirrors ``train_in_cl``: the encoder is in eval mode and runs under
    ``torch.no_grad()``; the classifier is in train mode; the return value is
    the number of correct ``argmax(1)`` predictions on the training batches.

    NOTE(review): ``pl``, ``xm`` and ``autocast`` are not imported anywhere in
    the visible notebook -- presumably ``torch_xla.distributed.parallel_loader``,
    ``torch_xla.core.xla_model`` and ``torch_xla.amp.autocast``; confirm and
    import them before calling this function.

    Returns:
        correct: number of correctly classified samples.
    """
    correct = 0
    encoder.eval()
    classifier.train()
    # Re-wrap the loader with the parallel loader for the single given device.
    loader = pl.ParallelLoader(loader, [device]).per_device_loader(device)
    for data, target in loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        # NOTE(review): autocast is given xm.xla_device(), not the `device` argument -- confirm they agree.
        with autocast(xm.xla_device()):
            # No gradients flow into the encoder.
            with torch.no_grad():
                z = encoder(data)
            out = classifier(z)
            loss = criterion(out, target)
        loss.backward()
        # The optimizer is stepped through xm rather than optimizer.step().
        xm.optimizer_step(optimizer)
        xm.mark_step()

        correct += (out.argmax(1) == target).sum().item()
    return correct

In [None]:
def save_checkpoint(filename, model, optimizer, scheduler):
    """Write the state dicts of model, optimizer and scheduler to ``<filename>.pth``.

    The saved object is a dict with the keys ``'model_sd'``, ``'optimizer_sd'``
    and ``'scheduler_sd'``.
    """
    parts = (('model', model), ('optimizer', optimizer), ('scheduler', scheduler))
    state = {f'{label}_sd': obj.state_dict() for label, obj in parts}
    target_path = f'{filename}.pth'
    torch.save(state, target_path)

def load_checkpoint(filename, model, optimizer=None, scheduler=None, map_location=None):
    """Restore state from ``<filename>.pth`` as written by ``save_checkpoint``.

    Args:
        filename: checkpoint path without the ``.pth`` suffix.
        model: module that receives the ``'model_sd'`` entry.
        optimizer: if given, receives the ``'optimizer_sd'`` entry; pass None
            (the default) to restore weights only, e.g. for inspection.
        scheduler: if given, receives the ``'scheduler_sd'`` entry.
        map_location: forwarded to ``torch.load``; set it (e.g. ``'cpu'``) to
            load a checkpoint saved on a device that is not available here.

    Note:
        ``torch.load`` unpickles the file; only load checkpoints you trust.
    """
    checkpoint = torch.load(f'{filename}.pth', map_location=map_location)
    model.load_state_dict(checkpoint['model_sd'])
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_sd'])
    if scheduler is not None:
        scheduler.load_state_dict(checkpoint['scheduler_sd'])

# Main

In [None]:
#instance
# Hyperparameters shared by the model, the loss and the wandb config; defined
# once so the logged configuration cannot drift from what is actually run.
BATCH_SIZE = 256
TIME_STEPS = 4
TAU = 0.2
LR = 0.6
MOMENTUM = 0.9

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#device = torch_xla.device()
#------------------------#
# loader : two-view augmented batches for contrastive pre-training (train + test images).
# loader2: plain CIFAR-10 test split used to fit the linear classifier on frozen features.
loader = get_loader('cifar10', split='all', batch_size=BATCH_SIZE, DA=True)
loader2 = get_loader('cifar10', split='test', batch_size=128, DA=False)
# get_loader uses drop_last=True, so the last partial batch is never seen;
# count only the samples the loader actually yields, otherwise acc is understated.
N = len(loader2) * loader2.batch_size
#------------------------#
encoder = SEW_ResNet(TIME_STEPS)
projector = Projector()
model = SimCLR(encoder, projector).to(device)
classifier = Classifier(128, 10).to(device)
#------------------------#
optimizer = optim.SGD(model.parameters(), lr=LR, momentum=MOMENTUM)
optimizer2 = optim.Adam(classifier.parameters())
scheduler1 = lrs.LinearLR(optimizer, 1, 0.5, 4)
#scheduler2 = lrs.CosineAnnealingLR(optimizer, 15, 0.1)
#scheduler = lrs.ChainedScheduler([scheduler1, scheduler2], optimizer)
#------------------------#
criterion = NT_Xent(BATCH_SIZE, TAU, device).to(device)
criterion2 = nn.CrossEntropyLoss()

wandb.login()
run = wandb.init(
    project = 'SimCLR SEW 1102',
    config = {
        'Architecture': 'SEWResNet14',
        'projector': 'ANN1layer',
        'feature dim': 128,
        'embedding dim': 64,
        'T': TIME_STEPS,
        'optim': 'SGD',
        'momentum': MOMENTUM,
        'lr': LR,
        'sche1': 'Linear(1, 0.5, 4)',
        'sche2': 'None',
        'sche': 'None',
        'change sche': 'None',
        'criterion': 'NT-Xent',
        'tau': TAU,
        'Data': 'Cifar10',
        'batch': BATCH_SIZE,
        'else': 'groups=2, down_sample.stride=2, ADD'
    }
)

#train
start_epoch = 0
epochs = 32
for epoch in range(start_epoch, epochs):
    # loss is the sum of the per-batch contrastive losses over the epoch.
    loss = train_cl(loader, model, optimizer, scheduler1, criterion, device)
    # acc is the classifier's running accuracy on the batches it is trained on
    # (the test split), not a held-out evaluation.
    correct = train_in_cl(loader2, model.encoder, classifier, optimizer2, criterion2, device)
    wandb.log({'loss': loss, 'acc': correct*100/N})
    print(f'Epoch: {epoch} | loss: {loss} | acc: {correct*100/N}%')

wandb.finish()
save_checkpoint('SimCLR_by_SEW_1110_ADD_3', model, optimizer, scheduler1)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:04<00:00, 41.9MB/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified
Files already downloaded and verified


[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

 ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mtk-0311-h[0m ([33mtk-0311-h-hosei-university[0m). Use [1m`wandb login --relogin`[0m to force relogin


100%|██████████| 234/234 [03:12<00:00,  1.22it/s]


Epoch: 0 | loss: 702.6946165561676 | acc: 20.96%


100%|██████████| 234/234 [03:09<00:00,  1.23it/s]


Epoch: 1 | loss: 634.743766784668 | acc: 19.99%


100%|██████████| 234/234 [03:09<00:00,  1.23it/s]


Epoch: 2 | loss: 535.2988519668579 | acc: 25.15%


100%|██████████| 234/234 [03:09<00:00,  1.23it/s]


Epoch: 3 | loss: 477.702418923378 | acc: 27.05%


100%|██████████| 234/234 [03:09<00:00,  1.23it/s]


Epoch: 4 | loss: 442.0260227918625 | acc: 29.98%


100%|██████████| 234/234 [03:09<00:00,  1.23it/s]


Epoch: 5 | loss: 425.7961231470108 | acc: 30.93%


100%|██████████| 234/234 [03:09<00:00,  1.23it/s]


Epoch: 6 | loss: 413.22518742084503 | acc: 33.29%


100%|██████████| 234/234 [03:09<00:00,  1.23it/s]


Epoch: 7 | loss: 403.3268154859543 | acc: 33.67%


100%|██████████| 234/234 [03:09<00:00,  1.23it/s]


Epoch: 8 | loss: 394.5040558576584 | acc: 36.23%


100%|██████████| 234/234 [03:09<00:00,  1.23it/s]


Epoch: 9 | loss: 388.5546405315399 | acc: 37.12%


100%|██████████| 234/234 [03:09<00:00,  1.23it/s]


Epoch: 10 | loss: 383.49152982234955 | acc: 37.54%


100%|██████████| 234/234 [03:09<00:00,  1.23it/s]


Epoch: 11 | loss: 378.68061113357544 | acc: 38.92%


100%|██████████| 234/234 [03:09<00:00,  1.23it/s]


Epoch: 12 | loss: 375.6810508966446 | acc: 40.03%


100%|██████████| 234/234 [03:09<00:00,  1.23it/s]


Epoch: 13 | loss: 372.57968175411224 | acc: 39.16%


100%|██████████| 234/234 [03:09<00:00,  1.23it/s]


Epoch: 14 | loss: 368.9892827272415 | acc: 40.6%


100%|██████████| 234/234 [03:09<00:00,  1.23it/s]


Epoch: 15 | loss: 366.4579008817673 | acc: 41.05%


100%|██████████| 234/234 [03:09<00:00,  1.23it/s]


Epoch: 16 | loss: 364.2126432657242 | acc: 41.72%


100%|██████████| 234/234 [03:09<00:00,  1.23it/s]


Epoch: 17 | loss: 361.8333238363266 | acc: 42.46%


100%|██████████| 234/234 [03:09<00:00,  1.23it/s]


Epoch: 18 | loss: 360.60885059833527 | acc: 42.34%


100%|██████████| 234/234 [03:09<00:00,  1.23it/s]


Epoch: 19 | loss: 358.78042018413544 | acc: 42.85%


100%|██████████| 234/234 [03:09<00:00,  1.23it/s]


Epoch: 20 | loss: 356.40119230747223 | acc: 43.36%


100%|██████████| 234/234 [03:09<00:00,  1.23it/s]


Epoch: 21 | loss: 355.54558634757996 | acc: 44.36%


100%|██████████| 234/234 [03:09<00:00,  1.23it/s]


Epoch: 22 | loss: 354.7785128355026 | acc: 43.21%


100%|██████████| 234/234 [03:09<00:00,  1.23it/s]


Epoch: 23 | loss: 353.0196750164032 | acc: 43.31%


100%|██████████| 234/234 [03:09<00:00,  1.23it/s]


Epoch: 24 | loss: 352.4152147769928 | acc: 44.38%


100%|██████████| 234/234 [03:09<00:00,  1.23it/s]


Epoch: 25 | loss: 350.8228083848953 | acc: 44.28%


100%|██████████| 234/234 [03:09<00:00,  1.23it/s]


Epoch: 26 | loss: 350.75747632980347 | acc: 44.63%


100%|██████████| 234/234 [03:09<00:00,  1.23it/s]


Epoch: 27 | loss: 349.05703341960907 | acc: 45.18%


100%|██████████| 234/234 [03:09<00:00,  1.23it/s]


Epoch: 28 | loss: 348.6615844964981 | acc: 45.13%


100%|██████████| 234/234 [03:09<00:00,  1.23it/s]


Epoch: 29 | loss: 347.6818438768387 | acc: 44.44%


100%|██████████| 234/234 [03:09<00:00,  1.23it/s]


Epoch: 30 | loss: 346.4684873819351 | acc: 45.82%


100%|██████████| 234/234 [03:09<00:00,  1.23it/s]


Epoch: 31 | loss: 346.2931708097458 | acc: 44.85%


0,1
acc,▁▁▂▃▄▄▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇█▇▇████████
loss,█▇▅▄▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
acc,44.85
loss,346.29317


**11/10**

encoder : SEW ResNet13, group=2, down_sample.stride=2

projector : 1layer ANN

optimizer : SGD, lr=1

scheduler : Linear(1->0.3) + Cosine(0.3->0.1)

Else : T=4, batch=256

environment : L4 GPU, 3min/epoch, 12.1/22 GB

-> smooth leaf 13

# Inspection

In [None]:
# Two-view augmented loader over the CIFAR-10 test split (batch of 64), used by the analysis cell below.
l = get_loader('cifar10', split='test', batch_size=64, DA=True)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:03<00:00, 47.6MB/s]


Extracting ./data/cifar-10-python.tar.gz to ./data


In [None]:
#Analysis
# Sanity check on CPU: build a freshly initialised model and compute the
# contrastive loss for a single augmented batch.
# Note: this rebinds the notebook-global `device` to CPU.
device = torch.device('cpu')
e = SEW_ResNet(4)
p = Projector()
m = SimCLR(e, p)
# `cl` is constructed but not used anywhere in this cell.
cl = Classifier(128, 10)
# The batch size given here (64) has to match the batch size of loader `l`.
cr = NT_Xent(64, 0.2, device)
#load state dict
#cp = torch.load('SimCLR_by_SEW_1109_ADD.pth', weights_only=True, map_location=device)
#m.load_state_dict(cp['model_sd'])

# With DA=True each batch is ((view1, view2), labels).
(d1, d2), t = next(iter(l))
z1, z2 = m(d1, d2)
loss = cr(z1, z2)
print(loss)

tensor(2.4113, grad_fn=<DivBackward0>)


In [None]:
# Result
# T4 | 13.2 / 15 GB | batch=512, T=2 | 2m / epoch | 10% after 1 epoch, 14% after 2 epochs
# L4 | 28.9 / 22 GB | same | 14% after 1 epoch
# TPU | error in diag | moving the positive-pair diag inside no_grad() did not help -> switched the loader to parallel_loader (T=1): ran at 1.5min/epoch
# TPU | same setup as above, increasing T: fails at T=2 -> with mixed precision T=2 OK, 2min/epoch
# Record of the run above | 13.23 -> 16.35 -> 18.44 -> 19.58 -> 19.96 -> 21.24 > 21.53 > 21.8

# Plan: use stride=2 in down_sample and groups=2. Params go from 1200k -> 500k
# T4 GPU | batch=256, T=3(9.8GB) OK(13%) | T=4(12.7GB) works(13%) | T=5 looks infeasible
# TPU | batch=256, T=1 OK(10.4%) | T=2 OK(12.2%) | T=3 OK(13.2%) | T=4 NO(Exhausted)
# For now: TPU N=256, T=3, epochs=32, saturates at 20%
# Switching the projector to an ANN gave Resource Exhausted
# L4 GPU | N=256 T=4 epochs=16 | 3min/epoch -> saturates at 17% by epoch 5 | z1, z2 look normal

# 11/7 z1, z2 problem solved
# hyperparameters same as above | 3.1min/epoch on L4 GPU | RAM usage is half -> there is still headroom |
# Went well up to 34% -> maybe lr was too small to climb the slope -> take the loss T-wise | either way, analyse after training for more epochs
# 11/9 make sure to save the state dict | nothing changed from above for the first 8 epochs, yet both loss and acc follow different curves -> unstable -> is going T-wise the only option after all?
# 11/10
# Make sure to save