In [None]:
import torch
import torchvision
import os

import torchvision.transforms as transforms
import torchvision.models as models
import torch.nn as nn
import torch.optim as optim

import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision.io import read_image

from transformers import ViTFeatureExtractor, ViTForImageClassification
#from transformers import BitForImageClassification, BeitConfig, BeitFeatureExtractor, Trainer, TrainingArguments
from PIL import Image

from tqdm import tqdm
from collections import defaultdict

In [None]:
# Use the first CUDA GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

device

In [None]:
class food_set(Dataset):
    """Labelled image dataset backed by a two-column CSV (image file name, label).

    The CSV rows are split by position: the first ``split_idx`` rows form the
    training split and the remaining rows form the validation split.

    Args:
        labels_file: Path to the CSV. Its first line is treated as a header and
            replaced by the column names ``img_name`` and ``label``.
        img_dir: Directory containing the image files named in the CSV.
        extractor: Callable that turns a PIL image into a mapping with a
            ``"pixel_values"`` entry (a Hugging Face feature extractor in this
            notebook).
        transform: Optional callable applied to the PIL image before feature
            extraction (used for training-time augmentation).
        settype: ``"train"`` or ``"val"``.
        split_idx: Row position at which the train/val split happens.

    Raises:
        ValueError: If ``settype`` is neither ``"train"`` nor ``"val"``.
    """

    def __init__(self, labels_file, img_dir, extractor, transform = None, settype = "train",
                 split_idx = 30000):
        if settype not in ("train", "val"):
            raise ValueError(f"settype must be 'train' or 'val', got {settype!r}")

        # header=0: the file's first line is the header and is replaced by `names`.
        # header=1 would discard the first line and use the second one (the first
        # real sample) as the header, silently losing that sample.
        all_labels = pd.read_csv(labels_file, names=['img_name', 'label'], header=0)
        if settype == "train":
            self.img_labels = all_labels.iloc[:split_idx]
        else:
            self.img_labels = all_labels.iloc[split_idx:]

        self.img_dir = img_dir
        self.feature_extractor = extractor
        self.transform = transform

    def __len__(self):
        """Number of samples in this split."""
        return len(self.img_labels)

    def __getitem__(self, idx):
        """Return ``(pixel_values, label)`` for the sample at position ``idx``."""
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        # Convert to 3-channel RGB so grayscale/RGBA files match the 3-element
        # normalisation below, and close the file handle once decoded.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        if self.transform:
            image = self.transform(image)

        # NOTE(review): depending on the installed transformers version, `mean`/`std`
        # may not be recognised (the processor's names are `image_mean`/`image_std`)
        # and could be ignored — confirm which normalisation is actually applied.
        features = self.feature_extractor(images=image,
                                          mean=[0.485, 0.456, 0.406],
                                          std=[0.229, 0.224, 0.225])["pixel_values"][0]
        label = self.img_labels.iloc[idx, 1]

        return features, label

In [None]:
class food_test(Dataset):
    """Unlabelled image dataset over every file in a directory (the test split).

    Args:
        img_dir: Directory whose entries are all image files to predict on.
        extractor: Callable that turns a PIL image into a mapping with a
            ``"pixel_values"`` entry (a Hugging Face feature extractor in this
            notebook).
        transform: Optional callable applied to the PIL image before feature
            extraction. Defaults to ``None`` (no transform).

    Items are ``(file_name, pixel_values)`` so predictions can be matched back
    to their files.
    """

    def __init__(self, img_dir, extractor, transform = None):
        self.img_dir = img_dir
        self.feature_extractor = extractor
        # Previously this read a global `transform` that might not exist yet.
        self.transform = transform
        # List the directory once, sorted, so the index -> file mapping is stable
        # and __getitem__ does not rescan the directory on every call.
        self.file_names = sorted(os.listdir(img_dir))

    def __len__(self):
        """Number of files found in ``img_dir`` at construction time."""
        return len(self.file_names)

    def __getitem__(self, idx):
        """Return ``(file_name, pixel_values)`` for the file at position ``idx``."""
        file_name = self.file_names[idx]
        img_path = os.path.join(self.img_dir, file_name)
        # Convert to 3-channel RGB so grayscale/RGBA files match the 3-element
        # normalisation below, and close the file handle once decoded.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        if self.transform:
            image = self.transform(image)

        features = self.feature_extractor(images=image,
                                          mean=[0.485, 0.456, 0.406],
                                          std=[0.229, 0.224, 0.225])["pixel_values"][0]

        return file_name, features

In [None]:
# Load the pretrained ViT-B/16 classifier and swap its classification head for a
# freshly initialised one sized for this dataset's classes.
NUM_CLASSES = 81

vit = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
# Take the head's input width from the existing layer instead of hard-coding 768.
vit.classifier = torch.nn.Linear(vit.classifier.in_features, NUM_CLASSES)
# Keep the config in sync with the new head; otherwise it still describes the
# checkpoint's original label set.
vit.config.num_labels = NUM_CLASSES

In [None]:
def train_loop(trainloader, valloader, model, criterion, optimizer, labels, 
               img_dir, feature_extractor = None, transform = None,
               epochs = 10, save_path = '../../vit_weights.pth', log_every = 200):
    """Train ``model`` and evaluate it on ``valloader`` after every epoch.

    After each epoch the model is evaluated with ``test_loop``. Whenever the
    validation accuracy improves, its weights are written to ``save_path``, and
    a loss/accuracy curve is plotted.

    Args:
        trainloader: Yields ``(image, label)`` training batches.
        valloader: Yields ``(image, label)`` validation batches.
        model: Module whose output supports ``output["logits"]``.
        criterion: Loss function taking ``(logits, labels)``.
        optimizer: Optimizer over ``model``'s parameters.
        labels, img_dir, feature_extractor, transform: Unused; kept so existing
            calls keep working.
        epochs: Number of passes over ``trainloader``.
        save_path: Where the best-validation-accuracy weights are saved.
        log_every: Print the training loss every this many batches.

    Returns:
        ``defaultdict(list)`` with keys ``"epoch"``, ``"loss"`` and ``"accuracy"``.
        Entry 0 holds hard-coded placeholder values (epoch 0, loss 8.7,
        accuracy 1.25). Partial stats are returned if training is interrupted
        with Ctrl-C (KeyboardInterrupt).
    """
    # Train on GPU if available (`device` is the notebook-level setting).
    model = model.to(device)

    stats = defaultdict(list)
    max_acc = 0

    try:
        # NOTE(review): hard-coded epoch-0 baseline; origin of 8.7 / 1.25 is not
        # shown here — TODO document or compute it.
        stats["epoch"].append(0)
        stats["loss"].append(8.7)
        stats["accuracy"].append(1.25)

        size = len(trainloader.dataset)

        for epoch in range(epochs):
            # Re-enter train mode every epoch, since evaluation may switch the model to eval mode.
            model.train()
            for batch, (image, label) in enumerate(trainloader):
                # Compute prediction and loss
                image, label = image.to(device), label.to(device)
                pred = model(image)
                loss = criterion(pred["logits"], label)

                # Backpropagation
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # Print progress
                if batch % log_every == 0:
                    seen = batch * len(image)
                    print(f"loss: {loss.item():>7f}  [{seen:>5d}/{size:>5d}] epoch: {epoch + 1}")

            # Evaluate after this epoch
            print(f"Epoch: {epoch + 1}")
            current_acc, avg_loss = test_loop(valloader, model, criterion)

            if current_acc > max_acc:
                # Checkpoint the model being trained (not a notebook global).
                torch.save(model.state_dict(), save_path)
                max_acc = current_acc

            # Store metadata
            stats["epoch"].append(epoch + 1)
            stats["accuracy"].append(current_acc)
            stats["loss"].append(avg_loss)

            _plot_stats(stats)

        return stats

    except KeyboardInterrupt:
        return stats


def _plot_stats(stats):
    """Plot loss (left axis) and accuracy (right axis) against epoch."""
    fig, ax1 = plt.subplots()

    color = "tab:blue"
    ax1.set_xlabel("Epoch")
    ax1.set_ylabel("Loss", color=color)
    ax1.plot(stats["epoch"], stats["loss"], color=color)
    ax1.tick_params(axis="y", labelcolor=color)

    ax2 = ax1.twinx()

    color = "tab:orange"
    ax2.set_ylabel("Accuracy", color=color)
    ax2.plot(stats["epoch"], stats["accuracy"], color=color)
    ax2.tick_params(axis="y", labelcolor=color)
    plt.show()
    # One figure is created per epoch; release it so they do not accumulate.
    plt.close(fig)


def test_loop(dataloader, model, criterion):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0

    with torch.no_grad():
        for image, label in dataloader:
            image, label = image.to(device), label.to(device)
            pred = model(image)
            test_loss += criterion(pred["logits"], label).item()
            correct += (pred["logits"].argmax(1) == label).type(torch.float).sum().item()

    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
    return (100 * correct), test_loss

In [None]:
# Loss and optimiser: cross-entropy over the class logits, SGD with momentum.
LEARNING_RATE = 0.001
MOMENTUM = 0.9

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(vit.parameters(), momentum=MOMENTUM, lr=LEARNING_RATE)
# Alternative optimiser (kept commented out):
# optimizer = optim.Adam(vit.parameters(), lr=0.00001)

In [None]:
# Training data: image folder plus the CSV mapping file names to labels.
img_dir = "../data/train_set/train_set"
labels = "../data/train_labels.csv"

# Each of the first three augmentations fires independently with this probability.
AUG_PROB = 0.3

# ColorJitter() with no arguments has brightness/contrast/saturation/hue all 0 and
# therefore never changes the image. These strengths make the jitter actually
# apply; they are a starting point — TODO tune.
COLOR_JITTER = transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.02)

transform = transforms.Compose(
                    [transforms.RandomApply([COLOR_JITTER], p=AUG_PROB),
                     transforms.RandomApply([transforms.Grayscale(3)], p=AUG_PROB),
                     transforms.RandomApply([transforms.RandomAffine(180)], p=AUG_PROB),
                     transforms.RandomHorizontalFlip()])

# NOTE(review): the preprocessing config is loaded from the in21k checkpoint, while
# the model weights come from 'google/vit-base-patch16-224' — confirm both use the
# same resize/normalisation settings.
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')


In [None]:
# Train/validation splits come from the same labelled CSV; only training is augmented.
BATCH_SIZE = 16

trainset = food_set(labels, img_dir, feature_extractor, transform = transform, settype = "train")
valset = food_set(labels, img_dir, feature_extractor, settype = "val")

# Shuffle only the training data. Shuffling cannot change validation accuracy,
# but test_loop averages per-batch mean losses. When the last batch is smaller,
# a shuffled validation set makes the reported loss vary from run to run.
trainloader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
valloader = DataLoader(valset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

In [None]:
# Unlabelled test images for the final predictions.
TEST_IMG_DIR = "../data/test_set/test_set"
test_set = food_test(TEST_IMG_DIR, feature_extractor)
testloader = DataLoader(dataset=test_set, batch_size=16)

len(test_set)

In [None]:
# Fine-tune the ViT; `stats` holds the per-epoch loss/accuracy history.
stats = train_loop(
    trainloader,
    valloader,
    vit,
    criterion,
    optimizer,
    labels=labels,
    img_dir=img_dir,
    feature_extractor=feature_extractor,
)

In [None]:
# Persist the training curve (columns: epoch, loss, accuracy).
STATS_PATH = "../../stats_vit.csv"
df_stats = pd.DataFrame.from_dict(stats)
df_stats.to_csv(STATS_PATH)

In [None]:
# Save the weights as they stand after training. The best-validation checkpoint is
# written separately by train_loop.
PATH = './transfered_vit.pth'
torch.save(obj=vit.state_dict(), f=PATH)

# To restore a saved checkpoint instead, uncomment:
# vit.load_state_dict(torch.load("../../other_vit_optimal.pth"))
# vit.to(device)

In [None]:
# Predict a class for every test image and collect (file name, label) pairs.
test_results = {"img_name" : [], "label": []}

# Inference only: eval() disables dropout-style layers (train_loop leaves the model
# in train mode), and no_grad() avoids building autograd graphs for every batch.
vit.eval()
with torch.no_grad():
    for file_names, images in tqdm(testloader):
        logits = vit(images.to(device))["logits"]
        test_results["img_name"].extend(file_names)
        test_results["label"].extend(logits.argmax(1).cpu().tolist())

In [None]:
# Tabulate the predictions as the submission (one row per test image).
test_df = pd.DataFrame(data=test_results, columns=["img_name", "label"])

test_df.head()

In [None]:
# Write the submission without the pandas index column.
SUBMISSION_PATH = "submission.csv"
test_df.to_csv(SUBMISSION_PATH, index=False)
# Alternative output file (kept commented out):
# test_df.to_csv("submission_vit_standard.csv", index=False)