In [None]:
import torch
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, random_split

def _ensure_rgb(img):
    """Return `img` as an RGB PIL image, converting only when its mode differs."""
    return img if img.mode == "RGB" else img.convert("RGB")


# Preprocessing shared by every split: force 3-channel RGB, resize to a fixed
# 224x224, convert to a float tensor, then normalize each channel with the
# standard ImageNet mean/std values.
transform = transforms.Compose(
    [
        transforms.Lambda(_ensure_rgb),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [None]:
import os

# CUDA_VISIBLE_DEVICES is only honoured if it is set before CUDA is first
# initialised in this process, so it must come before any torch.cuda call.
# (If the kernel already touched CUDA, restart it for this to take effect.)
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"

# torch.cuda.current_device() raises when no GPU is usable, so guard it.
if torch.cuda.is_available():
    print(torch.cuda.current_device())
else:
    print("CUDA is not available; running on CPU")

In [None]:
# Root folder in torchvision ImageFolder layout (one sub-directory per class);
# every image is passed through `transform` when it is indexed.
dataset_path = 'dataset/2018_train_mini'
original_dataset = datasets.ImageFolder(transform=transform, root=dataset_path)

In [None]:
# Quick sanity check: total sample count and number of class folders found.
for caption, count in (
    ("Total number of images: ", len(original_dataset)),
    ("Classes number: ", len(original_dataset.classes)),
):
    print(caption, count)

In [None]:
import matplotlib.pyplot as plt
import random

# Show a few random training images with their class names.
# The dataset yields *normalized* tensors (see the Normalize step in `transform`),
# whose values fall outside [0, 1]; imshow would clip them and show wrong colours,
# so undo the normalization before plotting.
normalize = next(
    (t for t in transform.transforms if isinstance(t, transforms.Normalize)), None
)
if normalize is not None:
    # Reshape to (C, 1, 1) so the per-channel stats broadcast over H and W.
    channel_mean = torch.tensor(normalize.mean).view(-1, 1, 1)
    channel_std = torch.tensor(normalize.std).view(-1, 1, 1)

for _ in range(4):
    image, label = original_dataset[random.randrange(len(original_dataset))]

    if normalize is not None:
        image = image * channel_std + channel_mean
    # (C, H, W) tensor -> (H, W, C) array in [0, 1] as expected by imshow.
    image = image.clamp(0, 1).permute(1, 2, 0).numpy()

    fig, ax = plt.subplots(figsize=(3, 3))
    ax.imshow(image)
    ax.set_title(f"Class: {original_dataset.classes[label]}")
    ax.axis("off")
    plt.show()

In [None]:
# Split the full dataset 70% / 15% / 15% into train / validation / test.
# A seeded generator makes the split identical on every run, so re-running this
# cell cannot move samples between splits while a trained model is in memory
# (which would leak training images into the test set).
SPLIT_SEED = 42
BATCH_SIZE = 32

ori_len = len(original_dataset)
train_size = int(0.7 * ori_len)
val_size = int(0.15 * ori_len)
test_size = ori_len - train_size - val_size  # remainder, so the sizes always sum to ori_len

train_data, val_data, test_data = random_split(
    original_dataset,
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Only training benefits from shuffling; evaluation accuracy is order-independent.
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_data, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
import torch.nn as nn
import torch.optim as optim

# Fetch the pretrained `dinov2_vits14_reg` backbone and the matching `_lc`
# (linear-classifier) hub model; only the latter's `linear_head` attribute is
# used further down.
# NOTE(review): the `_lc` hub model appears to bundle its own backbone copy,
# which this notebook never calls — confirm against the dinov2 hubconf.
backbone, head = (
    torch.hub.load('facebookresearch/dinov2', entry_point, pretrained=True)
    for entry_point in ('dinov2_vits14_reg', 'dinov2_vits14_reg_lc')
)

# Freeze every backbone parameter so only the classifier is trained.
backbone.requires_grad_(False)



In [None]:
class MyModel(nn.Module):
    """Backbone features + freshly initialised linear classifier.

    The classifier input is the concatenation of the class token from each of
    the last ``n_layers`` backbone layers plus the mean-pooled patch tokens of
    the last of those layers.

    Args:
        backbone: Module exposing
            ``get_intermediate_layers(x, n=..., return_class_token=True)``
            that returns ``(patch_tokens, class_token)`` pairs.
        head: Module whose ``linear_head`` attribute is replaced *in place* by a
            new ``nn.Linear``; any outside reference to ``head`` sees the new layer.
        in_features: Width of the concatenated feature vector. Must equal
            ``(n_layers + 1) * <token width>``; the default 1920 is the value
            this notebook uses with its ViT-S backbone (TODO confirm 5 x 384).
        num_classes: Number of output logits (default 10000).
        n_layers: How many trailing backbone layers contribute a class token.
    """

    def __init__(self, backbone, head, in_features=1920, num_classes=10000, n_layers=4):
        super(MyModel, self).__init__()
        self.backbone = backbone
        self.head = head
        self.n_layers = n_layers

        # Replace the pretrained classifier so the output size matches the dataset.
        self.head.linear_head = nn.Linear(in_features=in_features, out_features=num_classes)

    def forward(self, x):
        """Return class logits for the image batch ``x``."""
        # One (patch_tokens, class_token) pair per requested layer.
        intermediate_layers = self.backbone.get_intermediate_layers(
            x, n=self.n_layers, return_class_token=True
        )

        class_tokens = [layer[1] for layer in intermediate_layers]
        patch_tokens = intermediate_layers[-1][0]

        # Average over dim 1 (presumably the token axis — confirm for new backbones).
        pooled_tokens = patch_tokens.mean(dim=1)

        linear_input = torch.cat(class_tokens + [pooled_tokens], dim=-1)
        return self.head.linear_head(linear_input)

In [None]:
# Use the current CUDA device when available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = MyModel(backbone, head).to(device)

# Multi-class classification over integer class labels.
criterion = nn.CrossEntropyLoss()

# Optimise only the freshly created linear classifier. `head.parameters()` would
# also include every other sub-module of the hub head (NOTE(review): the `_lc`
# hub model appears to carry its own, unfrozen backbone copy — confirm), none of
# which MyModel.forward uses.
optimizer = optim.Adam(model.head.linear_head.parameters(), lr=1e-4)

In [None]:
from torchmetrics import Accuracy
from tqdm import tqdm

# train the model
def train_one_epoch(model, train_loader, criterion, optimizer, device):
    """Run one optimisation pass over ``train_loader``.

    Args:
        model: Module to train; it is switched to training mode.
        train_loader: Sized iterable of ``(inputs, labels)`` batches.
        criterion: Callable ``criterion(outputs, labels)`` returning a scalar loss tensor.
        optimizer: Optimizer over the parameters to update.
        device: Device every batch is moved to.

    Returns:
        Mean of the per-batch ``loss.item()`` values (0.0 for an empty loader).
    """
    model.train()
    running_loss = 0.0
    for inputs, labels in tqdm(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)

        # Clear the previous step's gradients: .backward() accumulates into
        # .grad, so without this every update would use the running sum.
        optimizer.zero_grad()

        # forward
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        # backward and optimize
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
    # max(..., 1) avoids ZeroDivisionError on an empty loader.
    return running_loss / max(len(train_loader), 1)

# define validation function
def evaluate(model, loader, device, num_classes=10000):
    """Compute top-1 multiclass accuracy of ``model`` over ``loader``.

    Args:
        model: Module to evaluate; it is switched to eval mode and left there.
        loader: Iterable of ``(inputs, labels)`` batches.
        device: Device the batches and the metric are moved to.
        num_classes: Number of classes for the metric; must match the model's
            output width. Defaults to 10000, the default output width of MyModel.

    Returns:
        Accuracy in [0, 1] as a Python float.
    """
    model.eval()
    accuracy_metric = Accuracy(task='multiclass', top_k=1, num_classes=num_classes).to(device)  # top-1 accuracy
    with torch.no_grad():
        for inputs, labels in tqdm(loader):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            accuracy_metric.update(outputs, labels)

    return accuracy_metric.compute().item()

In [None]:
# Train for a fixed number of epochs, checkpointing whenever validation accuracy improves.
num_epochs = 8
# Start below any attainable accuracy so the first epoch always writes a checkpoint.
# Starting at 0.0 would mean a run that never beats 0% leaves no 'best_model.pth'
# (or a stale one from an earlier run) for the later reload.
best_val_accuracy = float('-inf')

for epoch in range(num_epochs):
    train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device)
    val_accuracy = evaluate(model, val_loader, device)

    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.4f}, Val Accuracy: {val_accuracy * 100:.2f}%')

    # save the best model
    if val_accuracy > best_val_accuracy:
        best_val_accuracy = val_accuracy
        torch.save(model.state_dict(), 'best_model.pth')

In [None]:
# Reload the best checkpoint saved during training and report its test accuracy.
# map_location lets a checkpoint written on GPU load on a CPU-only machine;
# weights_only=True restricts unpickling to tensors and plain containers.
model.load_state_dict(torch.load('best_model.pth', map_location=device, weights_only=True))
test_accuracy = evaluate(model, test_loader, device)
print(f'Test Accuracy: {test_accuracy * 100:.2f}%')