# L2P (Learning to Prompt) implementation

# 1. Datasets
    - Split CIFAR-100 : Split to 10 tasks (10 classes per task)
    - 5-datasets : CIFAR-10, MNIST, Fashion-MNIST, SVHN, notMNIST

In [114]:
from torchvision import transforms, datasets
import torch
from urllib.request import urlretrieve
import zipfile

# dataset name, transform_train, transform_val, data_path
def get_dataset(dataset, transform_train, transform_val, data_path, download):
    """Build the (train, validation) dataset pair for a named benchmark.

    Args:
        dataset: One of "CIFAR10", "CIFAR100", "MNIST", "Fashion-MNIST",
            "SVHN" or "notMNIST".
        transform_train: Transform applied to training samples.
        transform_val: Transform applied to validation samples.
        data_path: Root directory the data is stored in / downloaded to.
        download: When true, fetch the data if needed. For notMNIST the zip
            archive is downloaded and extracted on every call.

    Returns:
        A ``(train, val)`` tuple of datasets.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    import os

    # Datasets whose torchvision constructor selects the split with a boolean
    # ``train`` flag, mapped to their attribute name in ``torchvision.datasets``.
    train_flag_datasets = {
        "CIFAR10": "CIFAR10",
        "CIFAR100": "CIFAR100",
        "MNIST": "MNIST",
        "Fashion-MNIST": "FashionMNIST",
    }

    if dataset in train_flag_datasets:
        factory = getattr(datasets, train_flag_datasets[dataset])
        train = factory(data_path, train=True, download=download, transform=transform_train)
        # The validation split must use the validation transform.
        val = factory(data_path, train=False, download=download, transform=transform_val)
    elif dataset == "SVHN":
        # SVHN has no ``train`` argument; the split is chosen by name.
        train = datasets.SVHN(data_path, split="train", download=download, transform=transform_train)
        val = datasets.SVHN(data_path, split="test", download=download, transform=transform_val)
    elif dataset == "notMNIST":
        root = data_path
        if download:
            data_url = "https://github.com/facebookresearch/Adversarial-Continual-Learning/raw/main/data/notMNIST.zip"
            zip_file_path = "{}/notMNIST.zip".format(root)
            # urlretrieve does not create missing directories.
            os.makedirs(root, exist_ok=True)
            print("Downloading notMNIST from {}".format(data_url))
            urlretrieve(data_url, zip_file_path)
            with zipfile.ZipFile(zip_file_path, 'r') as archive:
                archive.extractall(root)

        train = datasets.ImageFolder("{}/notMNIST/Train".format(root), transform=transform_train)
        # NOTE(review): assumes the archive ships a "Test" folder next to
        # "Train" -- confirm against the extracted layout.
        val = datasets.ImageFolder("{}/notMNIST/Test".format(root), transform=transform_val)
    else:
        raise ValueError("{} not found".format(dataset))
    return train, val

def get_transforms(is_train):
    """Return the preprocessing pipeline for the training or evaluation split.

    Args:
        is_train: When true, return the augmenting pipeline (random resized
            crop to 224x224 plus random horizontal flip); otherwise return the
            deterministic pipeline (resize to 256, centre crop to 224).

    Returns:
        A ``transforms.Compose`` that maps a PIL image to a float tensor.
    """
    if is_train:
        steps = [
            transforms.PILToTensor(),
            transforms.RandomResizedCrop(size=(224, 224)),
            transforms.RandomHorizontalFlip(p=0.5),
        ]
    else:
        steps = [
            transforms.PILToTensor(),
            transforms.Resize(256),
            transforms.CenterCrop(224),
        ]
    # PILToTensor keeps the image as uint8; convolution layers need a floating
    # point input, so convert (and rescale to [0, 1]) as the final step.
    # NOTE(review): single-channel datasets still yield 1-channel tensors here;
    # confirm how those are meant to feed a 3-channel model.
    steps.append(transforms.ConvertImageDtype(torch.float))
    return transforms.Compose(steps)

# Check datasets

In [115]:
from torchvision.transforms.functional import to_pil_image
import matplotlib.pyplot as plt
def print_img(dataset, idx):
    """Draw the image of sample ``idx`` from ``dataset`` in a new figure.

    The first element of the sample is converted to a PIL image and shown in
    the left cell of a 1x2 subplot grid with the gray colormap.
    """
    sample = dataset[idx]
    picture = sample[0]
    plt.figure()
    plt.subplot(1, 2, 1)
    plt.imshow(to_pil_image(picture), cmap='gray')


# Model : Vision Transformer ViT-B/16


In [116]:

import torch.nn as nn
import einops

# Takes an image as input, divides it into patches, and passes them through the embedding layer.
class PatchEmbed(nn.Module):
    """Cut an image into non-overlapping square patches and embed each one.

    The projection is a convolution whose kernel size and stride both equal
    ``patch_size``, so every output position corresponds to one patch.

    Args:
        img_size: Side length of the (square) input image.
        patch_size: Side length of each square patch.
        in_channels: Number of input image channels.
        embed_dim: Size of the embedding produced for every patch.
    """

    def __init__(self, img_size, patch_size, in_channels=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.embed_dim = embed_dim

        patches_per_side = img_size // patch_size
        self.n_patches = patches_per_side ** 2

        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, input):
        """Map ``(n, in_channels, H, W)`` images to ``(n, num_patches, embed_dim)``."""
        # (n, in_channels, H, W) -> (n, embed_dim, H // patch, W // patch)
        grid = self.proj(input)
        # Merge the two spatial axes, then put the patch axis before the
        # embedding axis: (n, embed_dim, P) -> (n, P, embed_dim).
        return grid.flatten(2, 3).transpose(1, 2)

class AttentionLayer(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        num_head: Number of attention heads; must divide ``dim``.
        dim: Embedding size of the input and output tokens.
        qkv_bias: Whether the joint query/key/value projection has a bias.

    Raises:
        ValueError: If ``dim`` is not divisible by ``num_head``.
    """

    def __init__(self,
                 num_head : int,
                 dim : int,
                 qkv_bias : bool = True):
        super().__init__()
        if dim % num_head != 0:
            raise ValueError(
                "dim ({}) must be divisible by num_head ({})".format(dim, num_head)
            )
        self.num_head = num_head
        self.dim = dim
        self.head_dim = dim // num_head
        self.qkv = nn.Linear(dim, dim*3, bias=qkv_bias)
        self.dropout_attn = nn.Dropout(0.)
        self.proj = nn.Linear(dim, dim)
        self.dropout_proj = nn.Dropout(0.)

    def forward(self, input):
        """Apply self-attention to ``input`` of shape ``(n, num_tokens, dim)``.

        Returns:
            A tensor with the same shape as ``input``.

        Raises:
            ValueError: If the last axis of ``input`` is not ``dim`` wide.
        """
        num_samples, num_tokens, dim = input.shape
        if dim != self.dim:
            raise ValueError(
                "expected last dimension {}, got {}".format(self.dim, dim)
            )

        # Joint projection, then split the last axis as (head, head_dim, qkv);
        # this is the same layout as the pattern 'b n (h d qkv)'.
        qkv = self.qkv(input).reshape(
            num_samples, num_tokens, self.num_head, self.head_dim, 3
        )
        # Each of q, k, v: (n, num_head, num_tokens, head_dim).
        q, k, v = (part.transpose(1, 2) for part in qkv.unbind(dim=-1))

        # Scale by the per-head dimension, not the full embedding size.
        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        # Normalise over the key axis so each query's weights sum to one.
        attn = self.dropout_attn(scores.softmax(dim=-1))

        # Weighted sum of values, contracting over the key/value token axis.
        x = torch.matmul(attn, v)
        # Merge heads back: (n, h, tokens, d) -> (n, tokens, h * d).
        x = x.transpose(1, 2).reshape(num_samples, num_tokens, dim)

        x = self.dropout_proj(self.proj(x))
        return x
            
class TransformerEncoderBlock(nn.Module):
    """Pre-norm encoder block: self-attention then an MLP, each with a skip.

    Args:
        num_head: Number of attention heads.
        embed_dim: Token embedding size.
        expansion: Width multiplier of the hidden MLP layer.
        drop: Dropout probability used inside the MLP.
    """

    def __init__(self,
                 num_head : int,
                 embed_dim : int,
                 expansion : int = 4,
                 drop: float = 0.):
        super().__init__()
        hidden_dim = expansion * embed_dim
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.attention = AttentionLayer(num_head, embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(drop),
            nn.Linear(hidden_dim, embed_dim),
        )

    def forward(self, input):
        """Return ``input`` after the attention and MLP residual sub-layers."""
        # First sub-layer: normalise, attend, add the skip connection.
        attended = self.attention(self.norm1(input)) + input
        # Second sub-layer: normalise, MLP, add the skip connection.
        return self.mlp(self.norm2(attended)) + attended
        
class Head(nn.Module):
    """Classification head: mean-pool tokens, layer-normalise, project to logits.

    The normalisation and linear layers are created once in ``__init__`` so
    that they are registered as sub-modules: their weights are then trained,
    saved with the model and moved along with it by ``.to()`` / ``.cuda()``.

    Args:
        embed_dim: Size of each incoming token embedding.
        num_class: Number of output classes.
    """

    def __init__(self,
                embed_dim : int,
                num_class : int):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_class = num_class
        self.norm = nn.LayerNorm(embed_dim)
        self.fc = nn.Linear(embed_dim, num_class)

    def forward(self, input):
        """Map ``(batch, tokens, embed_dim)`` to ``(batch, num_class)`` logits."""
        # Average over the token axis: (b, n, e) -> (b, e).
        pooled = input.mean(dim=1)
        return self.fc(self.norm(pooled))

class TransformerEncoder(nn.Sequential):
    """A sequential stack of ``depth`` encoder blocks.

    Args:
        num_head: Number of attention heads in every block.
        depth: Number of blocks in the stack.
        embed_dim: Token embedding size.
    """

    def __init__(self, num_head: int, depth: int, embed_dim: int):
        blocks = (TransformerEncoderBlock(num_head, embed_dim) for _ in range(depth))
        super().__init__(*blocks)
    
class VisionTransformer(nn.Module):
    """Vision Transformer: patch embedding, class token, encoder, head.

    Args:
        num_head: Number of attention heads per encoder block.
        num_class: Number of output classes.
        batch_size: Kept for backward compatibility; no longer used, because
            the class token is shared by all samples and expanded to the
            actual batch size at run time.
        img_size: Side length of the (square) input image.
        patch_size: Side length of each square patch.
        in_channels: Number of input image channels.
        embed_dim: Token embedding size.
        depth: Number of encoder blocks.
    """

    def __init__(self,
                num_head : int,
                num_class : int,
                batch_size : int,
                img_size : int = 224,
                patch_size : int = 16,
                in_channels : int = 3,
                embed_dim : int = 768,
                depth : int = 12,
                ):
        super().__init__()
        self.embedding = PatchEmbed(img_size, patch_size, in_channels=in_channels, embed_dim=embed_dim)
        # A single learnable token; a per-batch-slot token would break on any
        # batch whose size differs from the configured one.
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        self.num_patch = (img_size // patch_size) ** 2
        self.pos_embedding = nn.Parameter(torch.randn(1, self.num_patch + 1, embed_dim))
        self.transformer_encoder = TransformerEncoder(num_head, depth, embed_dim)
        self.head = Head(embed_dim, num_class)

    def forward(self, input):
        """Return class logits for ``input`` images of shape ``(batch, C, H, W)``."""
        x = self.embedding(input)  # (batch, num_patches, embed_dim)

        # Prepend the class token to every sample:
        # (batch, num_patches, embed_dim) -> (batch, num_patches + 1, embed_dim)
        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # The positional embedding broadcasts over the batch axis.
        x = x + self.pos_embedding

        x = self.transformer_encoder(x)
        return self.head(x)

# Train

In [117]:
class Config:
    """Hyper-parameters and paths for the experiment, set as plain attributes."""

    def __init__(self):
        # Model architecture.
        self.num_class, self.batch_size = 1000, 128
        self.img_size, self.patch_size, self.in_channels = 224, 16, 3
        self.embed_dim, self.depth, self.num_head = 768, 12, 12

        # Dataset name and filesystem locations.
        self.datasets = "CIFAR100"
        self.data_path = "C:/Users/99san/Workspace/L2P/data"
        self.summary_path = "C:/Users/99san/Workspace/L2P/summary"

        # Optimiser settings.
        self.lr, self.lr_momentum = 0.001, 0.9
        self.weight_decay = 5e-4

        # Epoch range.
        self.start_epoch, self.epochs = 0, 100
    

In [118]:
def accuracy(output, y):
    """Return the fraction of samples whose top-1 prediction equals the label.

    Args:
        output: Per-class scores of shape ``(batch, num_class)``.
        y: Integer class labels of shape ``(batch,)``.

    Returns:
        A Python float in ``[0, 1]``; ``0.0`` for an empty batch.
    """
    if len(y) == 0:
        return 0.0
    # The prediction is the index of the highest score, not the score itself.
    pred = output.argmax(dim=1).reshape(-1)
    correct = (pred == y.reshape(-1)).sum().item()
    return correct / len(y)
    
def train(train_loader, model, criterion, optimizer, epoch, writer):
    """Run one training epoch and log the mean loss and accuracy.

    Args:
        train_loader: Iterable of ``(input, target)`` batches.
        model: Network to optimise; batches are moved to the device of its
            parameters (CUDA is used if the model has no parameters).
        criterion: Loss function called as ``criterion(output, target)``.
        optimizer: Optimiser stepped once per batch.
        epoch: Epoch index used as the logging step.
        writer: Object with ``add_scalar(tag, value, step)``.

    Returns:
        The mean per-batch accuracy over the epoch as a float.
    """
    model.train()

    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cuda")

    total_loss = 0.
    total_acc = 0.
    for input, target in train_loader:
        input = input.to(device)
        target = target.to(device)

        # forward
        output = model(input)
        loss = criterion(output, target)

        # backward and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accumulate plain floats so no autograd history is kept alive.
        total_loss += loss.item()
        # accuracy() returns a float, so it must not be indexed.
        total_acc += accuracy(output.detach().float(), target)

    # Guard against an empty loader.
    num_batches = max(len(train_loader), 1)
    losses = total_loss / num_batches
    accs = total_acc / num_batches
    writer.add_scalar("Loss/train", losses, epoch)
    writer.add_scalar("Accuracy/train", accs, epoch)

    return accs
    
def validate(val_loader, model, criterion, epoch, writer=None):
    """Evaluate the model on ``val_loader`` and log the mean loss and accuracy.

    Args:
        val_loader: Iterable of ``(input, target)`` batches.
        model: Network to evaluate; batches are moved to the device of its
            parameters (CUDA is used if the model has no parameters).
        criterion: Loss function called as ``criterion(output, target)``.
        epoch: Epoch index used as the logging step.
        writer: Optional object with ``add_scalar(tag, value, step)``; when
            ``None`` nothing is logged.

    Returns:
        The mean per-batch accuracy as a float.
    """
    model.eval()

    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cuda")

    total_loss = 0.
    total_acc = 0.
    with torch.no_grad():
        for input, target in val_loader:
            input = input.to(device)
            target = target.to(device)

            output = model(input)
            loss = criterion(output, target)

            # accuracy() returns a float, so it must not be indexed.
            total_acc += accuracy(output.float(), target)
            total_loss += loss.item()

    # Average over the validation batches (not the training loader) and
    # guard against an empty loader.
    num_batches = max(len(val_loader), 1)
    losses = total_loss / num_batches
    accs = total_acc / num_batches
    if writer is not None:
        writer.add_scalar("Loss/val", losses, epoch)
        writer.add_scalar("Accuracy/val", accs, epoch)

    return accs

In [123]:
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter 

import os

configs = Config()

# Fail fast, before any data is loaded, when no GPU is available.
if not torch.cuda.is_available():
    print("cuda not enabled")
    raise ValueError

# The model must live on the GPU, like the batches fed to it.
# NOTE(review): configs.num_class is 1000 while the dataset loaded below has
# far fewer classes -- confirm the intended head size.
model = VisionTransformer(
    configs.num_head,
    configs.num_class,
    configs.batch_size,
    img_size=configs.img_size,
    patch_size=configs.patch_size,
    in_channels=configs.in_channels,
    embed_dim=configs.embed_dim,
    depth=configs.depth,
).cuda()

train_transforms = get_transforms(is_train=True)
val_transforms = get_transforms(is_train=False)

# NOTE(review): the dataset name is hard-coded although configs.datasets
# exists -- confirm which one is intended.
train_datasets, val_datasets = get_dataset("notMNIST", train_transforms, val_transforms, data_path=configs.data_path, download=False)

train_loader = DataLoader(
    train_datasets,
    batch_size = configs.batch_size,
    shuffle = True,
)

# Validation must read the validation split; shuffling is unnecessary there.
val_loader = torch.utils.data.DataLoader(
    val_datasets,
    batch_size = configs.batch_size,
    shuffle = False,
)

writer = SummaryWriter()

criterion = nn.CrossEntropyLoss().cuda()
# NOTE(review): configs.lr_momentum is defined but not passed to SGD --
# confirm whether momentum was intended.
optimizer = torch.optim.SGD(model.parameters(), configs.lr, weight_decay=configs.weight_decay)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=configs.epochs, eta_min=0)

# torch.save does not create missing directories.
os.makedirs(configs.summary_path, exist_ok=True)

best_acc = 0.0
for epoch in range(configs.start_epoch, configs.epochs):
    print("current lr {:.5e}".format(optimizer.param_groups[0]['lr']))
    train(train_loader, model, criterion, optimizer, epoch, writer)
    lr_scheduler.step()
    avg_acc = validate(val_loader, model, criterion, epoch, writer)

    # Keep only the checkpoint with the best validation accuracy so far.
    if avg_acc > best_acc:
        torch.save(model.state_dict(), "{}/best.pth".format(configs.summary_path))
        best_acc = avg_acc

writer.close()

    

cuda not enabled


ValueError: 