In [1]:
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from datasets import load_dataset

from PIL import Image


  from .autonotebook import tqdm as notebook_tqdm


In [6]:
cifar10 = load_dataset("cifar10")

# Define a transformation to convert PIL images to tensors
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.247, 0.243, 0.261]),
])

# Select subset of data
train_data = cifar10["train"].select(range(200))
test_data = cifar10["test"].select(range(100))

# Create PyTorch Dataset class
class Cifar10Dataset(torch.utils.data.Dataset):
    def __init__(self, dataset, transform=None):
        self.dataset = dataset
        self.transform = transform
    
    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, idx):
        sample = self.dataset[idx]
        image = sample["img"]
        label = sample["label"]
        if self.transform:
            print(image.size)
            image = self.transform(image.convert("RGB")) # Convert PIL image to Tensor

        return image, label

train_dataset = Cifar10Dataset(train_data, transform=transform)

In [7]:
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

In [10]:
for images, labels in train_loader:
    print(images.shape)
    break

<class 'PIL.PngImagePlugin.PngImageFile'>
(32, 32)
<class 'PIL.PngImagePlugin.PngImageFile'>
(32, 32)
<class 'PIL.PngImagePlugin.PngImageFile'>
(32, 32)
<class 'PIL.PngImagePlugin.PngImageFile'>
(32, 32)
<class 'PIL.PngImagePlugin.PngImageFile'>
(32, 32)
<class 'PIL.PngImagePlugin.PngImageFile'>
(32, 32)
<class 'PIL.PngImagePlugin.PngImageFile'>
(32, 32)
<class 'PIL.PngImagePlugin.PngImageFile'>
(32, 32)
<class 'PIL.PngImagePlugin.PngImageFile'>
(32, 32)
<class 'PIL.PngImagePlugin.PngImageFile'>
(32, 32)
<class 'PIL.PngImagePlugin.PngImageFile'>
(32, 32)
<class 'PIL.PngImagePlugin.PngImageFile'>
(32, 32)
<class 'PIL.PngImagePlugin.PngImageFile'>
(32, 32)
<class 'PIL.PngImagePlugin.PngImageFile'>
(32, 32)
<class 'PIL.PngImagePlugin.PngImageFile'>
(32, 32)
<class 'PIL.PngImagePlugin.PngImageFile'>
(32, 32)
<class 'PIL.PngImagePlugin.PngImageFile'>
(32, 32)
<class 'PIL.PngImagePlugin.PngImageFile'>
(32, 32)
<class 'PIL.PngImagePlugin.PngImageFile'>
(32, 32)
<class 'PIL.PngImagePlugin.PngI

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

from model import ViTforImageClassification
from dataset import get_cifar10_dataloader

from config_json import get_config

config = get_config()

# Define training parameters
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu") # This is for apple silicon
num_epochs = config["num_epochs"]
learning_rate = config["lr"]
save_path = "best_vit_model.pth"

# Initialize model, loss function, and optimizer
model = ViTforImageClassification(config).to(device)
optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-4)


  from .autonotebook import tqdm as notebook_tqdm


In [5]:
train_data_loader, val_data_loader = get_cifar10_dataloader(batch_size=config["batch_size"], shuffle=True)

In [12]:
best_val_acc = 0.0
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    correct, total = 0, 0
    
    loop = tqdm(train_data_loader, leave=True)
    for images, labels in loop:
        images, labels = images.to(device), labels.to(device)
        
        optimizer.zero_grad()
        loss, outputs = model(images, labels)
        print(outputs)
        print(outputs.shape)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        print(predicted, predicted.shape)
        break
    break
        # total += labels.size(0)
        # correct += (predicted == labels).sum().item()

  0%|          | 0/7 [00:00<?, ?it/s]

tensor([[-3.8790e+00, -5.5612e-01, -2.1052e+00,  4.0457e+00,  8.5984e-01,
          7.7154e+00, -6.9422e-01, -5.2390e+00, -3.9817e+00, -1.9394e-01],
        [-1.3675e+00, -1.9447e-01, -4.8972e-02, -2.0157e+00,  1.3705e-02,
         -9.7101e-01, -4.3103e+00, -2.8027e+00,  1.4293e+00,  7.5641e+00],
        [ 2.7155e+00, -2.4774e+00,  2.0959e+00,  6.2772e-01,  2.0166e+00,
         -4.9911e+00,  1.6773e+00,  6.9790e+00, -2.2916e+00, -2.5475e+00],
        [ 2.6525e+00, -4.8836e-01, -1.5520e+00, -1.6419e+00, -8.4108e-01,
         -2.5846e+00, -5.0165e+00, -7.1271e-01,  8.6391e+00,  1.7626e+00],
        [ 4.9850e-01,  5.0055e+00, -1.7327e+00,  2.1353e+00, -2.7851e+00,
         -1.1528e+00,  1.7800e-01, -4.4024e+00,  1.1931e+00,  1.7681e-01],
        [-1.6306e+00,  7.6581e-01,  3.3112e+00, -1.0177e-02, -2.6529e+00,
          6.0321e+00,  2.2091e+00, -3.2034e+00, -1.4141e+00, -4.2555e+00],
        [-1.0632e+00,  5.9228e+00,  1.3692e+00, -2.6003e-01, -4.9659e+00,
         -1.5858e+00, -5.4314e-0


