<a href="https://colab.research.google.com/github/TK-brsq/Research/blob/main/Template.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#!pip install spikingjelly
#!pip install wandb

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler as lrs
from torch.utils.data import DataLoader, ConcatDataset
from torchvision import transforms, datasets
'''
import spikingjelly
from spikingjelly import layer as jnn
from spikingjelly import neuron
from spikingjelly import functional as jF
'''
from tqdm import tqdm
import wandb

# Model

In [None]:
class BasicBlock(nn.Module):
    """Pre-activation residual block.

    The residual branch is BN -> ReLU -> 3x3 conv -> BN -> ReLU -> 3x3 conv.
    Its output is added to a shortcut that is either the unchanged input or
    a 1x1 convolution of it.

    Args:
        inplane: Number of input channels.
        outplane: Number of output channels.
        down_sampling: If True, the first 3x3 convolution and the shortcut
            projection use stride 2; otherwise they use stride 1.
    """

    def __init__(self, inplane, outplane, down_sampling=True):
        super().__init__()
        self.down_sampling = down_sampling
        self.stride = 2 if down_sampling else 1
        # The shortcut must be projected whenever the residual branch changes
        # the tensor shape, i.e. on a stride-2 step or a channel-count change.
        # Without this, a stride-1 block with inplane != outplane fails on the
        # residual addition.
        self.use_projection = down_sampling or inplane != outplane
        # Registered even when unused so the set of state_dict keys does not
        # depend on the block configuration.
        self.conv1x1 = nn.Conv2d(inplane, outplane, 1, self.stride, bias=False)

        layer = [
            nn.BatchNorm2d(inplane),
            nn.ReLU(),
            nn.Conv2d(inplane, outplane, 3, self.stride, 1, bias=False),
            nn.BatchNorm2d(outplane),
            nn.ReLU(),
            nn.Conv2d(outplane, outplane, 3, 1, 1, bias=False),
        ]
        self.layer = nn.Sequential(*layer)

    def forward(self, x):
        """Return ``residual(x) + shortcut(x)``."""
        shortcut = self.conv1x1(x) if self.use_projection else x
        return self.layer(x) + shortcut

class ResNet18(nn.Module):
    """Residual network: conv + max-pool stem, eight BasicBlocks, linear head.

    Args:
        classes: Number of outputs of the final bias-free linear layer.
    """

    def __init__(self, classes=10):
        super().__init__()

        # Stem: a 3x3 convolution keeps the spatial size, the max-pool halves it.
        stem_conv = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.first = nn.Sequential(stem_conv, nn.MaxPool2d(2))

        # Four stages of two blocks each; stages 2-4 open with a
        # down-sampling block that doubles the channel count.
        self.block1 = BasicBlock(64, 64, down_sampling=False)
        self.block2 = BasicBlock(64, 64, down_sampling=False)
        self.block3 = BasicBlock(64, 128, down_sampling=True)
        self.block4 = BasicBlock(128, 128, down_sampling=False)
        self.block5 = BasicBlock(128, 256, down_sampling=True)
        self.block6 = BasicBlock(256, 256, down_sampling=False)
        self.block7 = BasicBlock(256, 512, down_sampling=True)
        self.block8 = BasicBlock(512, 512, down_sampling=False)

        # Head: global average pooling -> flatten -> dropout -> linear.
        head = [
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Dropout(p=0.2),
            nn.Linear(512, classes, bias=False),
        ]
        self.last = nn.Sequential(*head)

    def forward(self, x):
        """Map an image batch to ``classes`` outputs per sample.

        For a 3x32x32 input the feature maps are 64x16x16 after the stem and
        blocks 1-2, 128x8x8 after blocks 3-4, 256x4x4 after blocks 5-6 and
        512x2x2 after blocks 7-8.
        """
        out = self.first(x)
        stages = (
            self.block1, self.block2, self.block3, self.block4,
            self.block5, self.block6, self.block7, self.block8,
        )
        for stage in stages:
            out = stage(out)
        return self.last(out)

In [None]:
class SimCLR(nn.Module):
    """Encoder followed by a two-layer projection head, applied to two views.

    Args:
        encoder: Module whose output is fed to the projector; the projector's
            first layer expects 512 input features.
    """

    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder
        # Projection head: 512 -> 512 -> ReLU -> 64.
        head = [nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 64)]
        self.projector = nn.Sequential(*head)

    def forward(self, xi, xj):
        """Return the projected embeddings ``(zi, zj)`` of the two views."""
        # Encode both views first, then project both, in that order.
        features = [self.encoder(view) for view in (xi, xj)]
        zi, zj = (self.projector(h) for h in features)
        return zi, zj

In [None]:
class Classifier(nn.Module):
    """Linear probe mapping 512 features to 10 class logits (no bias)."""

    def __init__(self):
        # The parent initializer must actually be *called*; without the call
        # nn.Module refuses to register sub-modules and construction fails.
        super().__init__()
        self.layer = nn.Sequential(
            nn.Linear(512, 10, bias=False)
        )

    def forward(self, x):
        """Return the logits for a batch whose last dimension is 512."""
        return self.layer(x)

In [None]:
class NT_Xent(nn.Module):
    """Normalized temperature-scaled cross-entropy loss over paired embeddings.

    Row ``k`` of ``zi`` and row ``k`` of ``zj`` form a positive pair; every
    other row of the concatenated batch is a negative for both of them.

    Args:
        batch_size: Expected number of pairs per batch. A mask for this size
            is built once; other sizes are handled on demand in ``forward``.
        tau: Temperature dividing the cosine similarities.
        device: Device on which the pre-built mask is stored.
    """

    def __init__(self, batch_size, tau, device):
        super().__init__()
        self.batch_size = batch_size
        self.tau = tau
        self.device = device
        self.mask = self.make_mask().to(device)
        self.cosine = nn.CosineSimilarity(dim=2)
        self.Xent = nn.CrossEntropyLoss()

    def make_mask(self, batch_size=None):
        """Return a (2N, 2N) boolean mask that is True on negative pairs.

        Self-similarities (main diagonal) and positive pairs (the two
        diagonals at offset +/-N) are False.

        Args:
            batch_size: Number of pairs N; defaults to ``self.batch_size``.
        """
        n = self.batch_size if batch_size is None else batch_size
        ones = torch.ones(n)
        excluded = torch.eye(2 * n) + torch.diag(ones, n) + torch.diag(ones, -n)
        return ~excluded.bool()

    def forward(self, zi, zj):
        """Return the scalar loss for two equally sized batches of embeddings.

        Raises:
            ValueError: If ``zi`` and ``zj`` hold a different number of rows.
        """
        n = zi.shape[0]
        if zj.shape[0] != n:
            raise ValueError(
                f'zi and zj must have the same batch size, got {n} and {zj.shape[0]}'
            )

        z = torch.cat((zi, zj), dim=0)
        # Pairwise cosine similarity of all 2N embeddings, scaled by 1/tau.
        similarity = self.cosine(z.unsqueeze(1), z.unsqueeze(0)) / self.tau

        # Use the actual batch size so a batch that differs from the
        # configured size (e.g. a final partial batch) still works.
        mask = self.mask if n == self.batch_size else self.make_mask(n)
        mask = mask.to(z.device)

        sim_ij, sim_ji = torch.diag(similarity, n), torch.diag(similarity, -n)
        positive = torch.cat((sim_ij, sim_ji), dim=0).reshape(2 * n, 1)
        negative = similarity[mask].reshape(2 * n, -1)

        # The positive logit sits in column 0, so the target class is always 0.
        logits = torch.cat((positive, negative), dim=1)
        labels = torch.zeros(2 * n, dtype=torch.long, device=z.device)
        loss = self.Xent(logits, labels)

        # CrossEntropyLoss already averages over the 2N rows; the extra 1/2
        # is a constant scale kept so loss values stay comparable across runs.
        return loss / 2

# Utils

In [None]:
class DataAugmentation:
    """Callable that returns two randomly augmented tensor views of one image.

    The pipeline is: random resized crop to 32 px (scale 0.49-1), horizontal
    flip (p=0.5), colour jitter (applied with p=0.8), random grayscale
    (p=0.2) and conversion to a tensor.
    """

    def __init__(self):
        jitter = transforms.ColorJitter(0.8, 0.8, 0.8, 0.2)
        steps = [
            transforms.RandomResizedCrop(32, (0.49, 1)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomApply([jitter], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            transforms.ToTensor(),
        ]
        self.tf = transforms.Compose(steps)

    def __call__(self, x):
        """Run the random pipeline twice on ``x`` and return both results."""
        first_view = self.tf(x)
        second_view = self.tf(x)
        return first_view, second_view

In [None]:
def get_loader(data='cifar10', split='train', DA=False, batch_size=128):
    """Build a shuffled DataLoader over CIFAR-10.

    Args:
        data: Dataset name; only ``'cifar10'`` is supported.
        split: ``'train'``, ``'test'`` or ``'all'`` (train and test concatenated).
        DA: If True, each sample is transformed by ``DataAugmentation`` (two
            augmented views); otherwise by ``transforms.ToTensor()``.
        batch_size: Number of samples per batch.

    Returns:
        A ``DataLoader`` with ``shuffle=True``, ``drop_last=True`` and two
        workers, or ``None`` (after printing a message) when ``data`` is not
        a supported dataset name.

    Raises:
        ValueError: If ``split`` is not one of the supported values.
    """
    if data != 'cifar10':
        print(f'{data} is not supported >_<')
        return None

    # Which values of the CIFAR10 ``train`` flag make up each split.
    train_flags_by_split = {
        'train': (True,),
        'test': (False,),
        'all': (True, False),
    }
    if split not in train_flags_by_split:
        # Previously an unknown split silently handed the dataset *name*
        # (a string) to DataLoader instead of a dataset.
        raise ValueError(
            f"unknown split {split!r}; expected one of {sorted(train_flags_by_split)}"
        )

    tf = DataAugmentation() if DA else transforms.ToTensor()
    parts = [
        datasets.CIFAR10('./data', train=flag, transform=tf, download=True)
        for flag in train_flags_by_split[split]
    ]
    dataset = parts[0] if len(parts) == 1 else ConcatDataset(parts)
    return DataLoader(dataset, batch_size, shuffle=True, drop_last=True, num_workers=2)

In [None]:
def train_(loader, model, optimizer, scheduler, criterion, device, SNN=False):
    """Run one supervised training epoch.

    Args:
        loader: Iterable of ``(data, target)`` batches.
        model: Network to train; switched to training mode.
        optimizer: Stepped once per batch.
        scheduler: Stepped once, after the last batch.
        criterion: Loss called as ``criterion(output, target)``.
        device: Device the batches are moved to.
        SNN: Currently unused.

    Returns:
        ``(loss_sum, n_correct)``: the sum of the per-batch loss values and
        the number of samples whose arg-max prediction equals the target.
    """
    # SNN hook (disabled): jF.reset_net(model)
    model.train()
    loss_sum, n_correct = 0, 0
    for inputs, labels in tqdm(loader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        logits = model(inputs)
        batch_loss = criterion(logits, labels)
        batch_loss.backward()
        optimizer.step()
        # TPU hook (disabled): xm.optimizer_step(optimizer)

        loss_sum += batch_loss.item()
        hits = logits.argmax(1) == labels
        n_correct += hits.sum().item()
    scheduler.step()
    return loss_sum, n_correct

In [None]:
def train_cl(loader, model, optimizer, scheduler, criterion, device, SNN=False):
    """Run one contrastive training epoch.

    Args:
        loader: Iterable of ``((xi, xj), _)`` batches; the second element of
            each batch is ignored.
        model: Called as ``model(xi, xj)`` and expected to return ``(zi, zj)``.
        optimizer: Stepped once per batch.
        scheduler: Stepped once, after the last batch.
        criterion: Loss called as ``criterion(zi, zj)``.
        device: Device the views and embeddings are moved to.
        SNN: Currently unused.

    Returns:
        The sum of the per-batch loss values.
    """
    # SNN hook (disabled): jF.reset_net(model)
    model.train()
    loss_sum = 0
    for views, _ in tqdm(loader):
        view_a, view_b = views
        view_a = view_a.to(device)
        view_b = view_b.to(device)

        optimizer.zero_grad()
        emb_a, emb_b = model(view_a, view_b)
        emb_a = emb_a.to(device)
        emb_b = emb_b.to(device)
        batch_loss = criterion(emb_a, emb_b)
        batch_loss.backward()
        optimizer.step()
        # TPU hook (disabled): xm.optimizer_step(optimizer)

        loss_sum += batch_loss.item()
    scheduler.step()
    return loss_sum

In [None]:
def train_in_cl(loader, encoder, classifier, optimizer, criterion, device, SNN=False):
    """Train a classifier for one epoch on top of a frozen encoder.

    The encoder runs in eval mode under ``torch.no_grad()``, so only the
    classifier receives gradients.

    Args:
        loader: Iterable of ``(data, target)`` batches.
        encoder: Feature extractor; switched to eval mode.
        classifier: Module trained on the encoder output; switched to
            training mode.
        optimizer: Stepped once per batch.
        criterion: Loss called as ``criterion(output, target)``.
        device: Device the batches are moved to.
        SNN: Currently unused.

    Returns:
        Number of samples whose arg-max prediction equals the target.
    """
    # SNN hooks (disabled): jF.reset_net(encoder); jF.reset_net(classifier)
    encoder.eval()
    classifier.train()
    n_correct = 0
    for inputs, labels in loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        with torch.no_grad():
            features = encoder(inputs)
        logits = classifier(features)
        batch_loss = criterion(logits, labels)
        batch_loss.backward()
        optimizer.step()
        # TPU hook (disabled): xm.optimizer_step(optimizer)

        hits = logits.argmax(1) == labels
        n_correct += hits.sum().item()
    return n_correct

In [None]:
def save_checkpoint(filename, model, optimizer, scheduler):
    """Write the model, optimizer and scheduler state dicts to ``<filename>.pth``.

    The file holds one dict with the keys ``'model_sd'``, ``'optimizer_sd'``
    and ``'scheduler_sd'``.
    """
    state = dict(
        model_sd=model.state_dict(),
        optimizer_sd=optimizer.state_dict(),
        scheduler_sd=scheduler.state_dict(),
    )
    target_path = f'{filename}.pth'
    torch.save(state, target_path)

def load_checkpoint(filename, model, optimizer, scheduler, map_location=None):
    """Restore model, optimizer and scheduler state from ``<filename>.pth``.

    The file must hold a dict with the keys ``'model_sd'``, ``'optimizer_sd'``
    and ``'scheduler_sd'``; each is loaded in place into the matching object.

    Args:
        filename: Checkpoint path without the ``.pth`` suffix.
        model: Module receiving ``'model_sd'``.
        optimizer: Optimizer receiving ``'optimizer_sd'``.
        scheduler: Scheduler receiving ``'scheduler_sd'``.
        map_location: Forwarded to ``torch.load``. Pass e.g. ``'cpu'`` to
            restore a checkpoint saved on a GPU on a machine without one.
            The default ``None`` keeps the tensors' saved locations.
    """
    checkpoint = torch.load(f'{filename}.pth', map_location=map_location)
    model.load_state_dict(checkpoint['model_sd'])
    optimizer.load_state_dict(checkpoint['optimizer_sd'])
    scheduler.load_state_dict(checkpoint['scheduler_sd'])

# Main

In [None]:
#wandb.login()

#instance
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#device = torch_xla.device()
# One batch size shared by the loader and the contrastive loss so the two
# cannot drift apart.
batch_size = 128
loader = get_loader('cifar10', split='train', DA=False, batch_size=batch_size)
# get_loader uses drop_last=True, so the final partial batch is never seen;
# count only the samples that are actually iterated in one epoch.
N = len(loader) * batch_size
resnet = ResNet18().to(device)
#simclr = SimCLR(encoder)
optimizer = optim.Adam(resnet.parameters())
#optimizer2 = optim.Adam(classifier.parameters())
# Linear warm-up for 8 scheduler steps, then a cyclic schedule.
scheduler1 = lrs.LinearLR(optimizer, start_factor=0.01, total_iters=8)
scheduler2 = lrs.CyclicLR(optimizer, base_lr=1e-4, max_lr=5e-4, step_size_up=5)
scheduler = lrs.SequentialLR(optimizer, [scheduler1, scheduler2], milestones=[8])
criterion = NT_Xent(batch_size, 0.2, device)  # contrastive loss (for train_cl)
criterion2 = nn.CrossEntropyLoss()            # supervised loss (for train_)

'''
#wandb
run = wandb.init(
    project = 'name',
    config = {
        'Architecture': 'x',
        'optim': 'Adam(1e-3)',
        'sche1': 'x',
        'sche2': 'x',
        'sche': 'x',
        'criterion': 'x',
        'Data': 'Cifar10',
        'else': 'x'
    }
)
'''
#train
start_epoch = 0
epochs = 8
#jF.set_step_mode(model, 'm')
for epoch in range(start_epoch, epochs):
    # loss is the sum of the per-batch losses; correct is a sample count.
    loss, correct = train_(loader, resnet, optimizer, scheduler, criterion2, device)
    #wandb.log({'loss': loss, 'acc': correct / N})
    # Scale to a real percentage so the value matches the '%' sign.
    print(f'Epoch: {epoch} | loss: {loss} | acc: {100 * correct / N:.2f}%')

#wandb.finish()
#save_checkpoint('name', simclr, optimizer, scheduler)

In [None]:
# Result
#