In [None]:
!pip install albumentations timm pandas

### Train

In [None]:
import warnings
warnings.filterwarnings('ignore')

import os
import pandas as pd
import numpy as np
import albumentations as A
import torch
import cv2
import timm
import random

from albumentations.pytorch import ToTensorV2
from PIL import Image
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader, Subset, random_split
from tqdm.auto import notebook_tqdm

def seed_everything(seed):
    """Seed every RNG this notebook uses and configure cuDNN for reproducibility.

    Seeds Python's `random`, NumPy's global RNG and torch's CPU/CUDA
    generators, and exports PYTHONHASHSEED (which only affects processes
    started afterwards; the running interpreter's hash seed is already fixed).

    Args:
        seed: integer seed applied to every generator.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU explicitly; DataParallel may spread work over several.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True makes cuDNN time and pick convolution algorithms per input
    # shape, and that choice can differ between runs. It must be off for
    # reproducibility, otherwise it undermines deterministic=True above.
    torch.backends.cudnn.benchmark = False


seed_everything(42)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
class CustomDataset(Dataset):
    """Image dataset backed by a CSV file.

    Column 0 of the CSV is joined onto ``root_dir`` to form each image path.
    In ``'train'`` mode the ``label`` column is replaced by integer ids from a
    ``LabelEncoder`` fitted on this CSV (kept as ``self.label_encoder``); in
    any other mode ``__getitem__`` returns ``-1`` as the label.

    Args:
        csv_file: path of the CSV to read.
        root_dir: directory prefix for the paths in column 0.
        mode: ``'train'`` to encode and return labels, anything else for unlabeled data.
        transform: optional albumentations-style callable used as
            ``transform(image=...)['image']``.
    """

    def __init__(self, csv_file, root_dir, mode='train', transform=None):
        self.df = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform
        self.mode = mode

        if mode == 'train':
            self.label_encoder = LabelEncoder()
            self.df['label'] = self.label_encoder.fit_transform(self.df['label'])

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for row ``idx``; label is ``-1`` outside train mode."""
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_name = os.path.join(self.root_dir, self.df.iloc[idx, 0])
        # convert('RGB') guarantees 3 channels (grayscale / RGBA / palette files
        # would otherwise break the 3-channel Normalize); the with-block closes
        # the file handle instead of leaking it per sample.
        with Image.open(img_name) as img:
            image = np.array(img.convert('RGB'))

        if self.transform:
            image = self.transform(image=image)['image']

        # Read the label by name: 'label' is the column encoded in __init__,
        # while a positional index silently breaks if the CSV column order changes.
        label = self.df['label'].iloc[idx] if self.mode == 'train' else -1

        return image, label

In [None]:
# Standard ImageNet per-channel statistics, shared by both pipelines.
mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)

# Training pipeline: resize, light random augmentation, normalise, to tensor.
train_transform = A.Compose(
    [
        A.Resize(height=224, width=224, interpolation=cv2.INTER_CUBIC),
        A.HorizontalFlip(p=0.5),
        A.RandomBrightnessContrast(p=0.2),
        A.Normalize(mean=mean, std=std),
        ToTensorV2(),
    ]
)

# Evaluation pipeline: same geometry and normalisation, no randomness.
test_transform = A.Compose(
    [
        A.Resize(height=224, width=224, interpolation=cv2.INTER_CUBIC),
        A.Normalize(mean=mean, std=std),
        ToTensorV2(),
    ]
)

In [None]:
train_dataset = CustomDataset(csv_file="./data/train.csv", root_dir="./data/", mode='train', transform=train_transform)
# Second view of the same CSV with the deterministic test pipeline: validation
# images must not go through the random flip / brightness-contrast augmentation.
# Both instances fit their LabelEncoder on the same file, so label ids agree.
val_source_dataset = CustomDataset(csv_file="./data/train.csv", root_dir="./data/", mode='train', transform=test_transform)
test_dataset = CustomDataset(csv_file="./data/test.csv", root_dir="./data/", mode='test', transform=test_transform)

total_train_samples = len(train_dataset)
val_size = int(0.1 * total_train_samples)
train_size = total_train_samples - val_size

# Split row indices (not a dataset) with a dedicated seeded generator, so the
# same rows land in validation on every run even if other cells consumed the
# global torch RNG first. A stable split is required when resuming from a
# checkpoint, or validation rows would leak into training.
split_generator = torch.Generator().manual_seed(42)
train_split, val_split = random_split(range(total_train_samples), [train_size, val_size], generator=split_generator)

train_subset = Subset(train_dataset, list(train_split))
val_subset = Subset(val_source_dataset, list(val_split))

train_loader = DataLoader(train_subset, batch_size=32, shuffle=True, num_workers=12)
val_loader = DataLoader(val_subset, batch_size=32, shuffle=False, num_workers=12)

In [None]:
model = timm.create_model("timm/maxvit_large_tf_224.in1k", pretrained=True, num_classes=25)
model.to(device)
# DataParallel prefixes state_dict keys with "module."; the inference section
# wraps its model the same way before loading, so the keys line up.
model = torch.nn.DataParallel(model)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
criterion = torch.nn.CrossEntropyLoss()
# LR-reduction patience must be shorter than the early-stopping patience
# (10 epochs, set in the checkpoint cell). ReduceLROnPlateau only reduces after
# MORE than `patience` bad epochs, while early stopping fires on the 10th, so
# with patience=10 here the reduction effectively never ran.
# `verbose` is dropped: it is deprecated, and the loop already prints the LR.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3, min_lr=1e-6)

In [None]:
def accuracy(pred, true):
    """Fraction of rows whose arg-max over dim 1 equals the target index.

    Args:
        pred: 2-D score tensor, one row per sample.
        true: tensor of target class indices, one per row of ``pred``.

    Returns:
        A 0-d float tensor (callers read it with ``.item()``).
    """
    predicted = torch.max(pred, dim=1).indices
    n_correct = (predicted == true).sum().item()
    return torch.tensor(n_correct / len(predicted))

def train(model, dataloader, criterion, optimizer, device):
    """Run one optimisation epoch over ``dataloader``.

    Returns:
        ``(mean loss, mean accuracy)``, each averaged over batches (not
        samples), so a short final batch weighs as much as a full one.
    """
    model.train()

    loss_sum = 0.0
    acc_sum = 0.0
    progress = notebook_tqdm(dataloader, desc='Training')
    for step, (images, labels) in enumerate(progress, start=1):
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        logits = model(images)
        batch_loss = criterion(logits, labels)
        batch_loss.backward()
        optimizer.step()

        loss_sum += batch_loss.item()
        acc_sum += accuracy(logits, labels).item()

        # Running averages shown live on the progress bar.
        progress.set_postfix({
            'loss': f'{loss_sum / step:.5f}',
            'accuracy': f'{acc_sum / step:.5f}',
        })

    n_batches = len(dataloader)
    return loss_sum / n_batches, acc_sum / n_batches

def validate(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` without updating it.

    Returns:
        ``(mean loss, mean accuracy)``, each averaged over batches.
    """
    model.eval()

    val_loss = 0.0
    val_acc = 0.0

    tqdm_bar = notebook_tqdm(dataloader, desc='Validation')
    # No backward pass happens here. Disabling autograd stops every forward from
    # storing activations for a gradient computation that never runs.
    with torch.no_grad():
        for batch_idx, (images, labels) in enumerate(tqdm_bar):
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            val_loss += loss.item()
            val_acc += accuracy(outputs, labels).item()

            avg_loss = val_loss / (batch_idx + 1)
            avg_acc = val_acc / (batch_idx + 1)

            tqdm_bar.set_postfix(
                {'loss': f'{avg_loss:.5f}',
                 'accuracy': f'{avg_acc:.5f}'}
            )
    return val_loss / len(dataloader), val_acc / len(dataloader)

In [None]:
best_weight_path = "./weights/maxvit-v1-best.pth"
current_weight_path = "./weights/maxvit-v1-current.pth"

patience = 10
num_epochs = 100

# torch.save does not create missing directories; make sure ./weights exists
# before the training loop tries to write checkpoints into it.
os.makedirs(os.path.dirname(current_weight_path), exist_ok=True)

if os.path.exists(current_weight_path):
    # Load onto the CPU: load_state_dict copies values into the existing model /
    # optimizer tensors on their own device, so mapping to the GPU here would
    # only hold an extra copy of the weights in GPU memory.
    checkpoint = torch.load(current_weight_path, map_location='cpu')

    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    # The original checkpoint format had no scheduler state; restore it only when present.
    if 'scheduler_state_dict' in checkpoint:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

    # A "current" file can exist without a "best" one if the run was interrupted
    # before the first best checkpoint was written; no best value is known then.
    if os.path.exists(best_weight_path):
        best_val_loss = torch.load(best_weight_path, map_location='cpu')['best_val_loss']
    else:
        best_val_loss = float('inf')

    early_stopping_counter = checkpoint['early_stopping_counter']
    start_epoch = checkpoint['epoch'] + 1
    print('Loaded model from last checkpoint')
    print("Last best validation loss: ", best_val_loss)
    print("Continuing from epoch: ", start_epoch)
    print("Early stopping counter: ", early_stopping_counter)
else:
    best_val_loss = float('inf')
    early_stopping_counter = 0
    start_epoch = 0

In [None]:
# torch.save does not create missing directories.
os.makedirs(os.path.dirname(current_weight_path), exist_ok=True)

for epoch in range(start_epoch, num_epochs):
    train_loss, train_acc = train(model, train_loader, criterion, optimizer, device)
    val_loss, val_acc = validate(model, val_loader, criterion, device)
    current_lr = optimizer.param_groups[0]['lr']
    print(f"Epoch {epoch+1}/{num_epochs}, LR: {current_lr}")
    print(f"Train Loss: {train_loss:.5f}, Train Accuracy: {train_acc:.5f}")
    print(f"Val Loss: {val_loss:.5f}, Val Accuracy: {val_acc:.5f}")

    scheduler.step(val_loss)

    # Update the early-stopping bookkeeping BEFORE checkpointing, so a resumed
    # run restores this epoch's counter rather than the previous epoch's.
    improved = val_loss < best_val_loss
    if improved:
        best_val_loss = val_loss
        early_stopping_counter = 0
    else:
        early_stopping_counter += 1

    current_checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'val_loss': val_loss,
        'early_stopping_counter': early_stopping_counter
    }

    # Save the weights with the best validation loss. It is written before the
    # "current" checkpoint, so an interruption between the two saves never
    # leaves a current checkpoint whose best counterpart is missing.
    if improved:
        best_checkpoint = {**current_checkpoint, 'best_val_loss': best_val_loss}
        torch.save(best_checkpoint, best_weight_path)
        print('Best model saved')

    torch.save(current_checkpoint, current_weight_path)
    print('Model saved')

    if early_stopping_counter >= patience:
        print(f'Early stopping at epoch: {epoch+1}')
        break
print('Training finished')

### Inference

In [None]:
import torch
import numpy as np
import pandas as pd
import cv2
import albumentations as A
import timm
import os

from albumentations.pytorch import ToTensorV2
from PIL import Image
from tqdm.auto import notebook_tqdm
from torch.utils.data import DataLoader, Dataset
from sklearn.preprocessing import LabelEncoder

In [None]:
# Inference configuration: the best-validation-loss checkpoint written by the
# training loop, and the device to run on.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model_path = "./weights/maxvit-v1-best.pth"

In [None]:
class CustomDataset(Dataset):
    """Image dataset backed by a CSV file.

    Column 0 of the CSV is joined onto ``root_dir`` to form each image path.
    In ``'train'`` mode the ``label`` column is replaced by integer ids from a
    ``LabelEncoder`` fitted on this CSV (kept as ``self.label_encoder``); in
    any other mode ``__getitem__`` returns ``-1`` as the label.

    Args:
        csv_file: path of the CSV to read.
        root_dir: directory prefix for the paths in column 0.
        mode: ``'train'`` to encode and return labels, anything else for unlabeled data.
        transform: optional albumentations-style callable used as
            ``transform(image=...)['image']``.
    """

    def __init__(self, csv_file, root_dir, mode='train', transform=None):
        self.df = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform
        self.mode = mode

        if mode == 'train':
            self.label_encoder = LabelEncoder()
            self.df['label'] = self.label_encoder.fit_transform(self.df['label'])

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for row ``idx``; label is ``-1`` outside train mode."""
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_name = os.path.join(self.root_dir, self.df.iloc[idx, 0])
        # convert('RGB') guarantees 3 channels (grayscale / RGBA / palette files
        # would otherwise break the 3-channel Normalize); the with-block closes
        # the file handle instead of leaking it per sample.
        with Image.open(img_name) as img:
            image = np.array(img.convert('RGB'))

        if self.transform:
            image = self.transform(image=image)['image']

        # Read the label by name: 'label' is the column encoded in __init__,
        # while a positional index silently breaks if the CSV column order changes.
        label = self.df['label'].iloc[idx] if self.mode == 'train' else -1

        return image, label

In [None]:
# Standard ImageNet per-channel statistics (same values as the training section).
mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)

# Deterministic evaluation pipeline: resize, normalise, convert to a torch tensor.
test_transform = A.Compose(
    [
        A.Resize(height=224, width=224, interpolation=cv2.INTER_CUBIC),
        A.Normalize(mean=mean, std=std),
        ToTensorV2(),
    ]
)

In [None]:
# NOTE(review): the training section builds a dataset over the same test.csv
# with root_dir="./data/", while this one uses "./data/test/". Only one of the
# two can match the paths in test.csv's first column; confirm which.
test_dataset = CustomDataset(
    csv_file="./data/test.csv",
    root_dir="./data/test/",
    mode='test',
    transform=test_transform,
)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=4)

# Same architecture as in training. No pretrained download is needed because
# the trained weights are loaded next. The model is wrapped in DataParallel, as
# in training, so the saved "module."-prefixed keys match.
model = timm.create_model("timm/maxvit_large_tf_224.in1k", pretrained=False, num_classes=25).to(device)
model = torch.nn.DataParallel(model)

In [None]:
# Restore the lowest-validation-loss weights saved by the training loop.
checkpoint = torch.load(model_path, map_location=device)
model.load_state_dict(state_dict=checkpoint["model_state_dict"])

def test(model, dataloder):
    """Predict a class index for every sample yielded by ``dataloder``.

    Batches are moved to the module-level ``device``; the labels the loader
    yields are ignored.

    Returns:
        A flat list of predicted class indices (arg-max over dim 1), in loader order.
    """
    model.eval()
    predictions = []
    with torch.no_grad():
        for images, _ in notebook_tqdm(dataloder, desc="Predicting"):
            logits = model(images.to(device))
            predicted = torch.max(logits, dim=1).indices
            predictions.extend(predicted.cpu().numpy())
    return predictions

In [None]:
test_predictions = test(model, test_loader)

# Decode class indices with an encoder fitted on the same CSV the training
# dataset encoded ("./data/train.csv"). Fitting on the same labels reproduces
# the training mapping exactly; a different file (e.g. a separately saved
# train_df.csv) is only safe if its label set is identical.
le = LabelEncoder()

train_df = pd.read_csv("./data/train.csv")
le.fit(train_df["label"])
# The classifier head has 25 outputs; any other class count cannot match it.
assert len(le.classes_) == 25, f"expected 25 classes, found {len(le.classes_)}"

final_prediction = le.inverse_transform(test_predictions)

In [None]:
# NOTE(review): assumes sample_submission.csv lists rows in the same order as
# test.csv (the test loader uses shuffle=False); confirm before submitting.
submission_df = pd.read_csv("./data/sample_submission.csv").assign(label=final_prediction)
submission_df.to_csv("./answer.csv", index=False)