In [1]:
import time

import PIL
import torch
import torchvision
import torch.nn.functional as F
from torch import nn
import torch.nn.init as init

from einops import rearrange
from common import *
from models.vt_resnet18 import VTResNet18
from TinyImageNet import TinyImageNet
from models.resnet import BasicBlock

In [2]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# local modules are picked up live, without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
# --- Configuration -----------------------------------------------------------
PATH_TO_IMAGE_NET = "./data/tiny-imagenet-200"
BATCH_SIZE_TRAIN = 100
BATCH_SIZE_VAL = 100

# Prefer the GPU, but fall back to the CPU so the notebook still runs (slowly)
# on machines without CUDA instead of failing at the first `.to(device)` call.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Per-channel mean / std given to Normalize; the same values are used for both
# splits so that train and validation inputs are scaled identically.
# (These are the widely used ImageNet channel statistics.)
NORMALIZE_MEAN = (0.485, 0.456, 0.406)
NORMALIZE_STD = (0.229, 0.224, 0.225)

# torchvision replaced RandomRotation's `resample=` keyword with
# `interpolation=` (taking an InterpolationMode) and later removed the old
# keyword. Pick whichever spelling the installed version understands so the
# cell works on both old and current releases.
if hasattr(torchvision.transforms, "InterpolationMode"):
    rotation_kwargs = {"interpolation": torchvision.transforms.InterpolationMode.BILINEAR}
else:
    rotation_kwargs = {"resample": PIL.Image.BILINEAR}

# --- Transforms --------------------------------------------------------------
# Training: random flip / rotation / affine augmentation, then tensor
# conversion and normalisation.
transform_train = torchvision.transforms.Compose([
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.RandomRotation(10, **rotation_kwargs),
    torchvision.transforms.RandomAffine(8, translate=(.15, .15)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(NORMALIZE_MEAN, NORMALIZE_STD),
])

# Validation: no augmentation, only tensor conversion and normalisation.
transform_val = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(NORMALIZE_MEAN, NORMALIZE_STD),
])

# --- Datasets and loaders ----------------------------------------------------
train_dataset = TinyImageNet(PATH_TO_IMAGE_NET, split='train', transform=transform_train, in_memory=False)
val_dataset = TinyImageNet(PATH_TO_IMAGE_NET, split='val', transform=transform_val, in_memory=False)

# Only the training split is shuffled; validation order stays fixed.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE_TRAIN, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=BATCH_SIZE_VAL, shuffle=False)



In [5]:
# Number of epochs handed to check_on_dataset below.
EPOCHS = 90

# Constructor arguments for the network, gathered in one place so the
# configuration can be read (and tweaked) at a glance.
# NOTE(review): input_dim is 224 while the transforms above perform no resize;
# confirm this matches the image size TinyImageNet actually yields.
vt_resnet18_config = dict(
    resnet_block=BasicBlock,
    layers=[2, 2, 2, 2],
    tokens=16,
    token_channels=1024,
    input_dim=224,
    layer_channels=[64, 128, 256, 512],
    num_classes=200,
)

# Build the model and move it onto the selected device.
model = VTResNet18(**vt_resnet18_config).to(device)

# Launch the run. check_on_dataset is not defined in this notebook (presumably
# it arrives via `from common import *`); judging by the logged output it
# prints per-epoch timings and periodic loss/accuracy, and saves a checkpoint.
# NOTE(review): the logged loss (~5.24) is close to ln(200) ≈ 5.30, i.e. near
# chance level for 200 classes — worth checking that the model is learning.
check_on_dataset(model, train_loader, val_loader, EPOCHS, "TinyImageNet", "fixed_ViTResNet18")


Epoch: 1
Execution time: 260.33 seconds
Epoch: 2
Execution time: 262.75 seconds
Epoch: 3
Execution time: 264.94 seconds
Epoch: 4
Execution time: 266.89 seconds
Epoch: 5
Execution time: 265.26 seconds
Epoch: 6
Execution time: 261.08 seconds
Epoch: 7
Execution time: 260.80 seconds
Epoch: 8
Execution time: 257.96 seconds
Epoch: 9
Execution time: 264.34 seconds
Epoch: 10
Execution time: 269.41 seconds

Average train loss: 5.2425

Train accuracy: 0.7960

Average test loss: 5.2382

Test accuracy: 0.7700
Saved model's checkpoint
Epoch: 11
Execution time: 267.40 seconds
Epoch: 12


KeyboardInterrupt: 