In [1]:
import torch
import torchvision
import os
import random

import torchvision.transforms as transforms
import torchvision.models as models
import torch.nn as nn
import torch.optim as optim

import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision.io import read_image

from transformers import BeitForImageClassification, BeitConfig, BeitFeatureExtractor, Trainer, TrainingArguments
from PIL import Image

from tqdm import tqdm
from collections import defaultdict

In [2]:
# Pick the compute device: first CUDA GPU when one is visible, otherwise the CPU
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

device

'cuda:0'

In [3]:
class food_set(Dataset):
    """Labelled image dataset backed by a CSV of (file name, label) rows.

    Each item is ``(features, label)``: ``features`` is the first
    ``"pixel_values"`` entry returned by ``extractor`` for the image, and
    ``label`` is the value stored in the second CSV column.

    Args:
        labels_file: Path to the CSV file; its two columns are read as
            ``img_name`` and ``label``.
        img_dir: Directory the ``img_name`` entries are relative to.
        extractor: Callable invoked as ``extractor(images=image)``; the result
            is indexed with ``["pixel_values"][0]``.
        transform: Optional callable applied to the PIL image before extraction.
        n: Optional collection of positional row indices selecting a subset of
            the CSV rows. ``None`` keeps every row; an empty collection yields
            an empty dataset.
    """

    def __init__(self, labels_file, img_dir, extractor, transform = None, n = None):
        self.n = n
        # NOTE(review): header=1 makes pandas discard line 0 and treat line 1 as
        # the (replaced) header, so if the CSV has a single header line the first
        # sample is dropped. Kept as-is; confirm against the file and switch to
        # header=0 if that is the case.
        img_labels = pd.read_csv(labels_file, names=['img_name', 'label'], header=1)
        # Compare against None explicitly: a plain truthiness test would treat an
        # empty index list as "no subset" and silently return the whole dataset.
        if n is not None:
            img_labels = img_labels.iloc[list(n)]
        self.img_labels = img_labels
        self.img_dir = img_dir
        self.feature_extractor = extractor
        self.transform = transform

    def __len__(self):
        """Return the number of (selected) CSV rows."""
        return len(self.img_labels)

    def __getitem__(self, idx):
        """Load, optionally transform and featurise the image at position ``idx``."""
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        # Decode inside a context manager so the file handle is released, and
        # convert to RGB so every image reaches the extractor with three
        # channels regardless of how it is stored on disk.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        if self.transform:
            image = self.transform(image)

        features = self.feature_extractor(images=image)["pixel_values"][0]
        label = self.img_labels.iloc[idx, 1]

        return features, label

In [4]:
class food_test(Dataset):
    """Unlabelled image dataset over every entry of a directory.

    Each item is ``(file_name, features)``: ``features`` is the first
    ``"pixel_values"`` entry returned by ``extractor`` for the image.

    Args:
        img_dir: Directory whose entries are all treated as images.
        extractor: Callable invoked as ``extractor(images=image)``; the result
            is indexed with ``["pixel_values"][0]``.
        transform: Optional callable applied to the PIL image before extraction.
    """

    def __init__(self, img_dir, extractor, transform = None):
        self.img_dir = img_dir
        self.feature_extractor = extractor
        self.transform = transform
        # List the directory once. Re-listing on every access costs a full
        # directory scan per item, and os.listdir gives no ordering guarantee,
        # so index -> file could differ between calls. Sorting pins the order.
        self.file_names = sorted(os.listdir(img_dir))

    def __len__(self):
        """Return the number of directory entries found at construction time."""
        return len(self.file_names)

    def __getitem__(self, idx):
        """Load, optionally transform and featurise the ``idx``-th file."""
        file_name = self.file_names[idx]
        img_path = os.path.join(self.img_dir, file_name)
        # Decode inside a context manager so the file handle is released, and
        # convert to RGB so every image reaches the extractor with three
        # channels regardless of how it is stored on disk.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        if self.transform:
            image = self.transform(image)

        features = self.feature_extractor(images=image)["pixel_values"][0]

        return file_name, features

In [6]:
# Load the pretrained BEiT image classifier
beit = BeitForImageClassification.from_pretrained("microsoft/beit-base-patch16-224")

# Replace the pretrained classification head with a fresh 768 -> 81 linear layer
beit.classifier = nn.Linear(in_features=768, out_features=81)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [7]:
def _plot_progress(stats):
    """Plot the recorded loss (left axis) and accuracy (right axis) per epoch."""
    fig, ax1 = plt.subplots()

    color = "tab:blue"
    ax1.set_xlabel("Epoch")
    ax1.set_ylabel("Loss", color=color)
    ax1.plot(stats["epoch"], stats["loss"], color=color)
    ax1.tick_params(axis="y", labelcolor=color)

    ax2 = ax1.twinx()

    color = "tab:orange"
    ax2.set_ylabel("Accuracy", color=color)
    ax2.plot(stats["epoch"], stats["accuracy"], color=color)
    ax2.tick_params(axis="y", labelcolor=color)
    plt.show()
    # One figure is created per epoch; close it so they do not pile up in memory.
    plt.close(fig)


def train_loop(dataset, model, criterion, optimizer, labels, 
               img_dir, feature_extractor = None, transform = None, train_size = 0.8,
               epochs = 40, batch_size = 16, save_path = '../../beit_weights.pth'):
    """Fine-tune ``model`` and validate it after every epoch.

    A random train/validation split of the row indices is drawn once; two
    ``food_set`` instances are built from ``labels``/``img_dir`` for those
    indices. After each epoch the validation set is scored with ``test_loop``,
    the weights are saved whenever validation accuracy improves, and a
    loss/accuracy plot is shown.

    Args:
        dataset: Only its length is used, to size the index split.
        model: Module called as ``model(image)`` returning a mapping with a
            ``"logits"`` entry.
        criterion: Loss called as ``criterion(logits, label)``.
        optimizer: Optimizer stepping the parameters of ``model``.
        labels: Path of the labels CSV handed to ``food_set``.
        img_dir: Image directory handed to ``food_set``.
        feature_extractor: Extractor handed to ``food_set``.
        transform: Optional transform handed to both the training and the
            validation ``food_set``.
        train_size: Fraction of the rows used for training.
        epochs: Number of passes over the training split.
        batch_size: Batch size of both data loaders.
        save_path: Where the best-scoring ``state_dict`` is written.

    Returns:
        A ``defaultdict(list)`` with ``"epoch"``, ``"loss"`` and ``"accuracy"``
        lists (validation loss and accuracy in percent, one entry per epoch,
        preceded by the hard-coded epoch-0 entry).
    """
    # Train on GPU if available
    model = model.to(device)

    # Metadata
    stats = defaultdict(list)
    max_acc = 0

    # NOTE(review): this epoch-0 entry is hard-coded, not measured here; it only
    # anchors the plot. Confirm where 8.7 / 1.25 come from.
    stats["epoch"].append(0)
    stats["loss"].append(8.7)
    stats["accuracy"].append(1.25)

    # Draw the split ONCE. Re-sampling it every epoch lets images that are used
    # for validation in one epoch be trained on in another, which inflates the
    # validation accuracy used for checkpoint selection.
    n_total = len(dataset)
    train_sample = int(n_total * train_size)

    # Sample from a range: random.sample() no longer accepts sets (Python 3.11+).
    n_train = random.sample(range(n_total), train_sample)
    train_lookup = set(n_train)
    n_val = [i for i in range(n_total) if i not in train_lookup]

    # NOTE(review): the validation set receives the same `transform` as the
    # training set; if it contains random augmentation, validation is noisy.
    trainset = food_set(labels, img_dir, feature_extractor, transform = transform, n = n_train)
    valset = food_set(labels, img_dir, feature_extractor, transform = transform,
                      n = n_val)

    # Shuffling only matters for training; validation order is kept fixed.
    trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=0)
    valloader = DataLoader(valset, batch_size=batch_size, shuffle=False, num_workers=0)

    size = len(trainloader.dataset)

    for epoch in range(epochs):
        # Re-enter training mode each epoch; validation below switches to eval.
        model.train()

        for batch, (image, label) in enumerate(trainloader):
            # Compute prediction and loss
            image, label = image.to(device), label.to(device)
            pred = model(image)
            loss = criterion(pred["logits"], label)

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Print progress
            if batch % 200 == 0:
                loss_value, current = loss.item(), batch * len(image)
                print(f"loss: {loss_value:>7f}  [{current:>5d}/{size:>5d}] epoch: {epoch + 1}")

        # Print performance after current number of epochs
        print(f"Epoch: {epoch + 1}")
        # Score in inference mode so layers such as dropout are disabled.
        model.eval()
        current_acc, avg_loss = test_loop(valloader, model, criterion)

        if current_acc > max_acc:
            # Save the model that was actually trained, not a notebook global.
            torch.save(model.state_dict(), save_path)
            max_acc = current_acc

        # Store metadata
        stats["epoch"].append(epoch + 1)
        stats["accuracy"].append(current_acc)
        stats["loss"].append(avg_loss)

        # Plot progress
        _plot_progress(stats)

    return stats


def test_loop(dataloader, model, criterion):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0

    with torch.no_grad():
        for image, label in dataloader:
            image, label = image.to(device), label.to(device)
            pred = model(image)
            test_loss += criterion(pred["logits"], label).item()
            correct += (pred["logits"].argmax(1) == label).type(torch.float).sum().item()

    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
    return (100 * correct), test_loss

In [8]:
# Loss: cross-entropy over the class logits
criterion = nn.CrossEntropyLoss()

# Optimizer: SGD with momentum over every BEiT parameter
optimizer = optim.SGD(params=beit.parameters(), lr=1e-3, momentum=0.9)
# Alternative that was tried: optim.Adam(beit.parameters(), lr=0.00001)

In [13]:
# Locations of the training images and their label file
img_dir = "../data/train_set/train_set"
labels = "../data/train_labels.csv"

# Random augmentation pipeline: each of the first three steps fires with p=0.3
# NOTE(review): `transform` is not passed to the dataset created below, nor to
# train_loop in the visible cells — confirm whether that is intended.
transform = transforms.Compose([
    transforms.RandomApply([transforms.ColorJitter()], p=0.3),
    transforms.RandomApply([transforms.Grayscale(3)], p=0.3),
    transforms.RandomApply([transforms.RandomAffine(180)], p=0.3),
    transforms.RandomHorizontalFlip(),
])

# Preprocessor matching the pretrained checkpoint, and the full labelled dataset
feature_extractor = BeitFeatureExtractor.from_pretrained("microsoft/beit-base-patch16-224")
dataset = food_set(labels, img_dir, feature_extractor)

In [14]:
# Unlabelled test images, served in batches of 16
test_set = food_test("../data/test_set/test_set", feature_extractor)
testloader = DataLoader(dataset=test_set, batch_size=16)

# Show how many test images were found
len(test_set)

7653

In [None]:
# Fine-tune BEiT; the returned dict holds per-epoch loss and accuracy
stats = train_loop(
    dataset,
    beit,
    criterion,
    optimizer,
    labels,
    img_dir,
    feature_extractor,
)

loss: 4.671617  [    0/24488] epoch: 1


In [27]:
# Persist the per-epoch training statistics as a CSV table
df_stats = pd.DataFrame(data=stats)
df_stats.to_csv(path_or_buf="../../stats.csv")

In [13]:
# PATH = './transfered_beit.pth'
# torch.save(beit.state_dict(), PATH)

# map_location remaps the stored tensors onto the active device, so a checkpoint
# written from the GPU also loads on a CPU-only machine.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
beit.load_state_dict(torch.load("../../other_beit_optimal.pth", map_location=device))

# Switch to inference mode before predicting; eval() returns the module, so the
# cell still displays the model summary.
beit.to(device).eval()

BeitForImageClassification(
  (beit): BeitModel(
    (embeddings): BeitEmbeddings(
      (patch_embeddings): PatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): BeitEncoder(
      (layer): ModuleList(
        (0): BeitLayer(
          (attention): BeitAttention(
            (attention): BeitSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=False)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
              (relative_position_bias): BeitRelativePositionBias()
            )
            (output): BeitSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (interme

In [18]:
test_results = {"img_name" : [], "label": []}

# Inference only: eval mode disables training-time layers, and no_grad avoids
# building an autograd graph for every batch (saves memory and time).
beit.eval()
with torch.no_grad():
    for file_name, image in tqdm(testloader):
        X = image.to(device)
        pred = beit(X)["logits"].argmax(1)

        # file_name holds the batch's file names; pred the matching class indices
        test_results["img_name"].extend(file_name)
        test_results["label"].extend(pred.cpu().tolist())

100%|████████████████████████████████████████| 479/479 [02:40<00:00,  2.98it/s]


In [19]:
# Turn the collected predictions into a two-column table and preview it
test_df = pd.DataFrame(data=test_results)

test_df.head()

Unnamed: 0,img_name,label
0,test_1.jpg,15
1,test_10.jpg,45
2,test_100.jpg,75
3,test_1000.jpg,29
4,test_1001.jpg,18


In [20]:
# Write the submission file without the DataFrame index column
test_df.to_csv(path_or_buf="../../submission_beit_other.csv", index=False)