<a href="https://www.kaggle.com/code/mohankumarmanepalli/vit-oct-classify?scriptVersionId=194673722" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

In [None]:
VERSION = "20200220" #@param ["20200220","nightly", "xrt==1.15.0"]
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --version $VERSION

In [None]:
!pip install vit-pytorch
import gc
import time
import shutil
from tqdm.notebook import tqdm
from vit_pytorch import SimpleViT
import torch
from torch import nn
from torchvision import transforms
# from transformers import ViTForImageClassification, ViTConfig, AutoImageProcessor
from torch.utils.data import DataLoader, random_split
from torch import optim
from torchvision.datasets import ImageFolder
from torch.utils.tensorboard import SummaryWriter

In [None]:
import torch_xla
import torch_xla.core.xla_model as xm

In [8]:
lr = 1e-4
device = xm.xla_device() # torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch_size = 32
img_size=224
epochs = 100

writer = SummaryWriter()
train_ds_path = '/kaggle/input/kermany2018/OCT2017 /train'
test_ds_path = '/kaggle/input/kermany2018/OCT2017 /test'
val_ds_path = '/kaggle/input/kermany2018/OCT2017 /val'

In [9]:
transform = transforms.Compose([
    transforms.RandomResizedCrop(img_size),
    transforms.RandomVerticalFlip(),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

train_gen = ImageFolder(train_ds_path, transform=transform)
test_gen = ImageFolder(test_ds_path, transform=transform)
val_gen = ImageFolder(val_ds_path, transform=transform)

In [10]:
train_size = int(0.5 * len(train_gen))
val_size = int(0.25 * len(train_gen))
test_size = int(0.25 * len(train_gen))

train_dataset, val_dataset, test_dataset = random_split(train_gen, [train_size, val_size, test_size])

In [12]:
import os
train_ds = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=os.cpu_count())
test_ds = DataLoader(test_gen, shuffle=False, num_workers=os.cpu_count())
val_ds = DataLoader(val_gen, shuffle=True, num_workers=os.cpu_count())

In [13]:
class AverageMeter(object):
    def __init__(self, name, fmt=':f'):
        self.count = None
        self.sum = None
        self.avg = None
        self.val = None
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)


class ProgressMeter(object):
    def __init__(self, num_batches, *meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def print(self, batch):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print('\t'.join(entries))

    @staticmethod
    def _get_batch_fmtstr(num_batches):
        num_digits = len(str(num_batches // 1))
        fmt = '{:' + str(num_digits) + 'd}'
        return '[' + fmt + '/' + fmt.format(num_batches) + ']'


def adjust_learning_rate(optimizer, epoch):
    global lr
    lr = lr * (0.1 ** (epoch // 30))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def accuracy(output, target, topk=(1,)):
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


def writer_fn(mode: str, epoch, loss, acc1, acc5):
    global writer
    writer.add_scalar(mode + "_Loss", loss, epoch)
    writer.add_scalar(mode + "_acc1", acc1, epoch)
    writer.add_scalar(mode + "_acc3", acc5, epoch)

In [14]:
def train(train_loader, model, criterion, optimizer, epoch, device):
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top3 = AverageMeter('Acc@3', ':6.2f')
    progress = ProgressMeter(len(train_loader), batch_time, data_time, losses, top1,
                             top3, prefix="Epoch: [{}]".format(epoch))
    model.to(device)
    model.train()

    end = time.time()
    for i, (images, target) in enumerate(tqdm(train_loader)):
        data_time.update(time.time() - end)

        images = images.to(device)
        target = target.to(device)

        output = model(images)
        loss = criterion(output, target)

        acc1, acc3 = accuracy(output, target, topk=(1, 3))
        losses.update(loss.item(), images.size(0))
        top1.update(acc1[0], images.size(0))
        top3.update(acc3[0], images.size(0))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_time.update(time.time() - end) 
        end = time.time()

        if i % 10 == 0:
            progress.print(i)
    writer_fn("Train", epoch, losses.avg, top1.avg, top3.avg)


def validate(val_loader, model, criterion, device, epoch):
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top3 = AverageMeter('Acc@3', ':6.2f')
    progress = ProgressMeter(len(val_loader), batch_time, losses, top1, top3,
                             prefix='Test: ')
    model.to(device)
    model.eval()

    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(tqdm(val_loader)):

            images = images.to(device)
            target = target.to(device)

            output = model(images)
            loss = criterion(output, target)

            acc1, acc3 = accuracy(output, target, topk=(1, 3))
            losses.update(loss.item(), images.size(0))
            top1.update(acc1[0], images.size(0))
            top3.update(acc3[0], images.size(0))

            batch_time.update(time.time() - end)
            end = time.time()

            if i % 10 == 0:
                progress.print(i)
        writer_fn("Val", epoch, losses.avg, top1.avg, top3.avg)
        print(' * Acc@1 {top1.avg:.3f} Acc@3 {top3.avg:.3f}'
              .format(top1=top1, top3=top3))

    return top1.avg


def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')

In [15]:
model=SimpleViT(
    image_size = 224,
    patch_size = 32,
    num_classes = 4,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048
)
# model= nn.DataParallel(model)
optimizer = optim.AdamW(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss().to(device)

In [None]:
best_acc1 = 0
gc.collect()
for epoch in tqdm(range(epochs)):
    gc.collect()
#     adjust_learning_rate(optimizer, epoch)
    train(train_ds, model, criterion, optimizer, epoch, device)
    validate(val_ds, model, criterion, device, epoch)

  0%|          | 0/100 [00:00<?, ?it/s]

  self.pid = os.fork()


Epoch: [0][   0/1305]	Time  1.325 ( 1.325)	Data  0.627 ( 0.627)	Loss 1.4369e+00 (1.4369e+00)	Acc@1   6.25 (  6.25)	Acc@3  87.50 ( 87.50)
Epoch: [0][  10/1305]	Time  0.096 ( 0.206)	Data  0.007 ( 0.063)	Loss 1.2708e+00 (2.3381e+00)	Acc@1  46.88 ( 38.35)	Acc@3  93.75 ( 89.77)
Epoch: [0][  20/1305]	Time  0.098 ( 0.155)	Data  0.009 ( 0.040)	Loss 1.2542e+00 (1.8711e+00)	Acc@1  31.25 ( 36.61)	Acc@3  90.62 ( 89.14)
Epoch: [0][  30/1305]	Time  0.145 ( 0.139)	Data  0.095 ( 0.036)	Loss 1.2961e+00 (1.6637e+00)	Acc@1  40.62 ( 38.81)	Acc@3  93.75 ( 89.92)
Epoch: [0][  40/1305]	Time  0.098 ( 0.131)	Data  0.011 ( 0.034)	Loss 1.3018e+00 (1.5678e+00)	Acc@1  43.75 ( 40.47)	Acc@3  84.38 ( 88.72)
Epoch: [0][  50/1305]	Time  0.143 ( 0.127)	Data  0.093 ( 0.035)	Loss 1.4102e+00 (1.5136e+00)	Acc@1  31.25 ( 40.50)	Acc@3  81.25 ( 88.54)
Epoch: [0][  60/1305]	Time  0.097 ( 0.124)	Data  0.007 ( 0.034)	Loss 1.3288e+00 (1.4708e+00)	Acc@1  40.62 ( 41.19)	Acc@3  87.50 ( 88.58)
Epoch: [0][  70/1305]	Time  0.106 ( 0.121

  self.pid = os.fork()


Test: [ 0/32]	Time  0.173 ( 0.173)	Loss 2.1341e+00 (2.1341e+00)	Acc@1   0.00 (  0.00)	Acc@3 100.00 (100.00)
Test: [10/32]	Time  0.009 ( 0.023)	Loss 2.6866e+00 (1.3505e+00)	Acc@1   0.00 ( 45.45)	Acc@3   0.00 ( 81.82)
Test: [20/32]	Time  0.008 ( 0.016)	Loss 2.4122e+00 (1.2299e+00)	Acc@1   0.00 ( 52.38)	Acc@3   0.00 ( 85.71)
Test: [30/32]	Time  0.006 ( 0.013)	Loss 2.0178e+00 (1.3234e+00)	Acc@1   0.00 ( 48.39)	Acc@3 100.00 ( 83.87)
 * Acc@1 50.000 Acc@3 84.375
Epoch: [1][   0/1305]	Time  0.551 ( 0.551)	Data  0.491 ( 0.491)	Loss 9.5561e-01 (9.5561e-01)	Acc@1  75.00 ( 75.00)	Acc@3  84.38 ( 84.38)
Epoch: [1][  10/1305]	Time  0.094 ( 0.139)	Data  0.011 ( 0.056)	Loss 1.1641e+00 (1.0263e+00)	Acc@1  50.00 ( 59.09)	Acc@3  87.50 ( 90.62)
Epoch: [1][  20/1305]	Time  0.093 ( 0.120)	Data  0.014 ( 0.038)	Loss 8.6318e-01 (1.0233e+00)	Acc@1  65.62 ( 59.97)	Acc@3  96.88 ( 90.03)
Epoch: [1][  30/1305]	Time  0.099 ( 0.113)	Data  0.006 ( 0.030)	Loss 1.2026e+00 (1.0462e+00)	Acc@1  53.12 ( 58.77)	Acc@3  93.75 

In [None]:
writer.flush()
writer.close()

In [None]:
gc.collect()

In [None]:
torch.save(model.state_dict(),'model_weights.pt')
torch.save(model,'model.pt')