## Dependencies

In [None]:
%matplotlib inline
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from IPython import display
import torch
from torchvision import transforms
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from tqdm import tqdm

## Dataset processing

In [None]:
# Directory holding the competition CSVs/images, and the directory outputs are written to.
# Both keep their trailing '/' because the rest of the notebook builds paths by plain
# string concatenation (e.g. INPUT_PATH + 'train.csv').
INPUT_PATH = '/kaggle/input/classify-leaves/'
OUTPUT_PATH = '/kaggle/working/'

In [None]:
# Training metadata: column 0 is read as the image paths, column 1 as the class labels.
train_data = pd.read_csv(INPUT_PATH + 'train.csv')

train_img_paths, train_img_labels = (
    train_data.iloc[:, column].values for column in (0, 1)
)

# Number of distinct classes present in the training CSV.
num_labels = np.unique(train_img_labels).size
print("Total label:", num_labels)

### Label map

In [None]:
# Build label <-> integer index lookup tables
def get_class_map(labels):
    """Build the two lookup tables between class names and integer ids.

    Ids are assigned in the order returned by ``np.unique`` (i.e. sorted
    label order), starting at 0.

    Args:
        labels: array-like of class labels (duplicates allowed).

    Returns:
        ``(class2num_map, num2class_map)`` - label -> id and id -> label.
    """
    ordered_labels = np.unique(labels)

    class2num_map: dict[str, int] = {}
    num2class_map: dict[int, str] = {}
    for class_id, class_name in enumerate(ordered_labels):
        class2num_map[class_name] = class_id
        num2class_map[class_id] = class_name

    return class2num_map, num2class_map

# Module-level lookup tables; LeafDataset.__getitem__ and get_str_labels read these globals.
class2num, num2class = get_class_map(train_img_labels)

### Test input shape

In [None]:
def extract_tensor(index):
    """Load one training image as a float32 tensor with a leading batch axis.

    Args:
        index: row position in ``train_data``; column 0 supplies the image path
            relative to ``INPUT_PATH``.

    Returns:
        The image after ``permute(2, 0, 1)`` and ``unsqueeze(0)``
        (presumably 1 x C x H x W, assuming ``plt.imread`` yields H x W x C).
    """
    relative_path = train_data.iloc[index, 0]
    pixels = plt.imread(INPUT_PATH + relative_path)
    image = torch.tensor(pixels, dtype=torch.float32)
    channels_first = image.permute(2, 0, 1)
    return channels_first.unsqueeze(0)


# Quick sanity check of the raw image shape.
print("Input shape:", extract_tensor(0).shape)

### Define dataset

In [None]:
class LeafDataset(Dataset):
    def __init__(self, csv_path, resize_height, resize_width, valid_ratio=0.2, mode="train"):
        self.csv_path = csv_path
        self.resize_height = resize_height
        self.resize_width = resize_width

        self.valid_ratio = valid_ratio
        self.mode = mode

        self.data_info = pd.read_csv(self.csv_path)
        # self.img_paths = self.data_info.iloc[:, 0].values

        self.data_len = len(self.data_info.index)
        self.train_len = int(self.data_len * (1 - self.valid_ratio))

        def split_data(data: pd.DataFrame, train_len):
            unique_labels, counts = np.unique(data.iloc[:, 1].values, return_counts=True)
            # train data should include all classes, if one label is less than 2, then it should be in train data
            valid_labels_pool = unique_labels[counts > 1]
            valid_data_pool = data[data.iloc[:, 1].isin(valid_labels_pool)]
            valid_index_pool = valid_data_pool.index
            valid_index: list[int] = np.random.choice(valid_index_pool, size=train_len)
            train_index: list[int] = np.setdiff1d(data.index, valid_index)
            return data.iloc[train_index], data.iloc[valid_index]

        if self.mode == "train":
            train_data_info, _ = split_data(self.data_info, self.train_len)

            self.img_paths = train_data_info.iloc[:, 0].values
            self.img_labels = train_data_info.iloc[:, 1].values
        elif self.mode == "valid":
            _, valid_data_info = split_data(self.data_info, self.train_len)

            self.img_paths = valid_data_info.iloc[:, 0].values
            self.img_labels = valid_data_info.iloc[:, 1].values

        elif self.mode == "test":
            self.img_paths = self.data_info.iloc[:, 0].values
            self.img_labels = ["None" for _ in range(self.data_len)]  # no label for initialization
        else:
            raise ValueError("Unknown mode")

        self.real_len = len(self.img_paths)  # real length of dataset

    def __len__(self):
        return self.real_len

    def __getitem__(self, idx) -> tuple[torch.Tensor, int | str]:
        img_path = INPUT_PATH + self.img_paths[idx]
        img_label = self.img_labels[idx]
        # ori_img_tensor = torch.tensor(plt.imread(img_path), dtype=torch.float32).permute(2, 0, 1).unsqueeze()

        if self.mode == "train" or self.mode == "valid":
            transform = transforms.Compose([
                transforms.RandomHorizontalFlip(0.5),
                transforms.RandomVerticalFlip(0.5),
                transforms.Resize((self.resize_height, self.resize_width)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                    std=[0.299, 0.224, 0.225])
            ])
            img_tensor: torch.Tensor = transform(Image.open(img_path))

            label_int = class2num[img_label]
            return img_tensor, label_int

        elif self.mode == "test":
            transform = transforms.Compose([
                transforms.Resize((self.resize_height, self.resize_width)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                    std=[0.299, 0.224, 0.225])
            ])
            img_tensor: torch.Tensor = transform(Image.open(img_path))

            if type(img_label) == str: # return label as string
                return img_tensor, img_label
            elif type(img_label) == int:
                return img_tensor, num2class[img_label]
            else:
                return img_tensor, img_label

        else:
            raise ValueError("Unknown mode")


# Train and validation sets are both carved out of train.csv; the test set reads test.csv.
# All images are resized to 224x224.
train_set = LeafDataset(csv_path=INPUT_PATH+"train.csv", resize_height=224, resize_width=224, mode="train")
valid_set = LeafDataset(csv_path=INPUT_PATH+"train.csv", resize_height=224, resize_width=224, mode="valid")
test_set = LeafDataset(csv_path=INPUT_PATH+"test.csv", resize_height=224, resize_width=224, mode="test")

## Model

### Define model (ResNet)

In [None]:
IN_CHANNELS = 3   # channels of the input images fed to the first convolution
NUM_LABELS = 176  # width of the final classification layer


def get_model_custom():  # ResNet18
    """Build the hand-written ResNet-18-style network (randomly initialised).

    Layout: a 7x7 stem, four stages of two residual blocks each (every stage
    followed by ``nn.Dropout()``), global average pooling and a linear
    classifier with ``NUM_LABELS`` outputs.

    Returns:
        The assembled ``nn.Sequential`` model.
    """

    class Residual(nn.Module):
        """Two 3x3 conv + batch-norm layers with a skip connection."""

        def __init__(self, input_channels, num_channels, strides=1, use_1x1conv=False):
            super().__init__()

            # Sub-modules are created in this exact order so that parameter
            # order and state_dict keys stay stable.
            self.conv1 = nn.Conv2d(input_channels, num_channels,
                                   kernel_size=3, padding=1, stride=strides)
            self.conv2 = nn.Conv2d(num_channels, num_channels,
                                   kernel_size=3, padding=1)

            if use_1x1conv:
                # 1x1 projection so the shortcut matches the main path's shape
                self.res_conv = nn.Conv2d(input_channels, num_channels,
                                          kernel_size=1, stride=strides)

            self.bn1 = nn.BatchNorm2d(num_channels)
            self.bn2 = nn.BatchNorm2d(num_channels)

        def forward(self, X):
            # Shortcut branch: projected only when the block owns a 1x1 conv.
            shortcut = self.res_conv(X) if hasattr(self, 'res_conv') else X
            out = F.relu(self.bn1(self.conv1(X)))
            out = self.bn2(self.conv2(out))
            out += shortcut
            return F.relu(out)

    def resnet_block(input_channels, num_channels, num_residuals, first_block=False):
        """Return the list of residual units that make up one stage."""

        def make_unit(position):
            # Every stage except the first opens with a stride-2 unit that
            # also changes the channel count.
            if position == 0 and not first_block:
                return Residual(input_channels, num_channels, strides=2, use_1x1conv=True)
            return Residual(num_channels, num_channels)

        return [make_unit(position) for position in range(num_residuals)]

    stem = nn.Sequential(
        nn.Conv2d(IN_CHANNELS, 64, kernel_size=7, stride=2, padding=3),
        nn.BatchNorm2d(64),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
    )

    # (input channels, output channels, is-first-stage) for the four stages
    stage_specs = ((64, 64, True), (64, 128, False), (128, 256, False), (256, 512, False))
    stages = [
        nn.Sequential(*resnet_block(c_in, c_out, 2, first_block=is_first), nn.Dropout())
        for c_in, c_out, is_first in stage_specs
    ]

    return nn.Sequential(
        stem,
        *stages,
        nn.AdaptiveAvgPool2d((1, 1)),
        nn.Flatten(),
        nn.Linear(512, NUM_LABELS),
    )

def get_model_ResNet18():
    """Return torchvision's pretrained ResNet-18 with a new ``NUM_LABELS``-way head."""
    backbone = models.resnet18(pretrained=True)
    classifier = nn.Linear(backbone.fc.in_features, NUM_LABELS)
    # Kept wrapped in nn.Sequential so the head's state_dict keys stay "fc.0.*".
    backbone.fc = nn.Sequential(classifier)
    return backbone

def get_model_ResNet50():
    """Return a pretrained 50-layer backbone with a new ``NUM_LABELS``-way head.

    Note: despite the function name, the backbone loaded here is
    ``resnext50_32x4d``.
    """
    backbone = models.resnext50_32x4d(pretrained=True)  # resnext50_32x4d
    classifier = nn.Linear(backbone.fc.in_features, NUM_LABELS)
    # Kept wrapped in nn.Sequential so the head's state_dict keys stay "fc.0.*".
    backbone.fc = nn.Sequential(classifier)
    return backbone

def get_model(name):
    """Build the model selected by ``name``.

    Args:
        name: ``"ResNet18"``, ``"ResNet50"`` or ``"custom"``.

    Returns:
        A freshly constructed model.

    Raises:
        ValueError: if ``name`` is not one of the supported names.
    """
    if name == "ResNet18":
        return get_model_ResNet18()
    if name == "ResNet50":
        return get_model_ResNet50()
    if name == "custom":
        return get_model_custom()
    raise ValueError("Unknown model name")

## Train

### Hyperparameters

In [None]:
# Hyperparameters forwarded to train() further down in the notebook.
BATCH_SIZE = 16         # samples per DataLoader batch
NUM_EPOCHS = 50         # full passes over the training set
LEARNING_RATE = 0.0001  # Adam learning rate
WEIGHT_DECAY = 0.008    # Adam weight_decay

In [None]:
# Echo the hyperparameters so they are recorded in the cell output.
for caption, setting in (
    ("Batch size:", BATCH_SIZE),
    ("Number of epochs:", NUM_EPOCHS),
    ("Learning rate:", LEARNING_RATE),
    ("Weight decay:", WEIGHT_DECAY),
):
    print(caption, setting)

### Training

In [None]:
# Get device if CUDA is available
def get_device(device='auto'):
    """Resolve a device request to a ``torch.device``.

    Args:
        device: ``'cpu'``, ``'cuda'``, or anything else (e.g. ``'auto'``) to
            pick automatically.

    Returns:
        The CPU device whenever CUDA is unavailable; otherwise the requested
        device for ``'cpu'``/``'cuda'`` and CUDA for every other value.
    """
    if not torch.cuda.is_available():
        return torch.device('cpu')
    if device == 'cpu' or device == 'cuda':
        return torch.device(device)
    # Any other request means "auto"; CUDA is known to be available here.
    return torch.device('cuda')


# Randomly initialize weights
def init_weights(m):
    """Xavier-uniform initialise the weight of a Linear or Conv2d module in place.

    Intended for ``model.apply``; modules of any other exact type (and all
    biases) are left untouched.
    """
    if type(m) in (nn.Linear, nn.Conv2d):
        nn.init.xavier_uniform_(m.weight)  # initialize weights


def _run_epoch(model, data_iter, loss_fn, device, optimizer=None):
    """Run one pass over ``data_iter`` and return ``(mean loss, accuracy)``.

    With an ``optimizer`` the model is put in training mode and updated after
    every batch; without one the model is put in eval mode and the pass runs
    with autograd disabled. Both metrics are averaged per sample.
    """
    training = optimizer is not None
    model.train(training)

    loss_sum, correct, seen = 0.0, 0, 0
    with torch.set_grad_enabled(training):
        for X, y in tqdm(data_iter):
            X = X.to(device)
            y = y.to(device)
            y_hat = model(X)
            l = loss_fn(y_hat, y)

            if training:
                optimizer.zero_grad()
                l.backward()
                optimizer.step()

            batch_len = y.shape[0]
            # CrossEntropyLoss returns the batch mean; weight it by the batch
            # size so a smaller last batch does not skew the epoch average.
            loss_sum += l.item() * batch_len
            # take max as the predicted label
            correct += (y_hat.argmax(dim=1) == y).sum().item()
            seen += batch_len

    if seen == 0:
        return float("nan"), float("nan")
    return loss_sum / seen, correct / seen


def train(model, train_set, valid_set, batch_size, num_epochs, lr, weight_decay, device='auto'):
    """Train ``model``, plot the loss curves live and save the final weights.

    Args:
        model: network to train. NOTE: every Linear/Conv2d weight is re-initialised
            via ``init_weights`` before training, which also overwrites pretrained
            weights if the model came with any.
        train_set: dataset yielding ``(image, integer label)`` for training.
        valid_set: dataset yielding ``(image, integer label)`` for validation.
        batch_size: batch size for both data loaders.
        num_epochs: number of passes over ``train_set``.
        lr: Adam learning rate.
        weight_decay: Adam weight decay.
        device: forwarded to ``get_device``.

    Returns:
        The trained model (left in eval mode), whose state_dict has also been
        written to ``OUTPUT_PATH`` under a timestamped file name.
    """
    device = get_device(device)

    print("Training via %s..." % device.type.upper())

    # Training batches are reshuffled every epoch; validation order does not matter.
    train_iter = DataLoader(train_set, batch_size=batch_size, num_workers=5, shuffle=True)
    valid_iter = DataLoader(valid_set, batch_size=batch_size, num_workers=5, shuffle=False)

    model = model.to(device)
    model.apply(init_weights)  # initialize weights

    loss = nn.CrossEntropyLoss()  # loss function
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)  # optimizer

    # for paint
    epoch_list = []
    y_train_loss_list = []
    y_valid_loss_list = []

    for epoch in range(num_epochs):
        # _run_epoch switches train/eval mode itself, so dropout and batch-norm
        # are active in every training pass and frozen in every validation pass.
        train_loss, train_acc = _run_epoch(model, train_iter, loss, device, optimizer)
        valid_loss, valid_acc = _run_epoch(model, valid_iter, loss, device)

        # paint (redrawn after every epoch)
        epoch_list.append(epoch + 1)
        y_train_loss_list.append(train_loss)
        y_valid_loss_list.append(valid_loss)

        plt.cla()
        plt.title("Loss Curve")
        plt.xlabel("Epoch")
        plt.ylabel("Loss")
        plt.plot(epoch_list, y_train_loss_list)
        plt.plot(epoch_list, y_valid_loss_list, color="orange")
        plt.legend(["Train Loss", "Valid Loss"])
        display.clear_output(wait=True)
        plt.pause(0.00000001)

        # output info
        print(
            'Epoch %d, Train Loss %.4f, Train Acc %.3f\n%sValid Loss %.4f, Valid Acc %.3f' %
            (epoch + 1,
             train_loss, train_acc,
             " " * len("Epoch %d, " % (epoch + 1)),
             valid_loss, valid_acc))

    print("Training done")

    # save model
    model_path = OUTPUT_PATH + "model_{}.pth".format(pd.Timestamp.now().strftime("%Y%m%d%H%M%S"))
    torch.save(model.state_dict(), model_path)
    print("Model saved at", model_path)

    return model


# Build the hand-written ResNet and train it with the hyperparameters defined above.
model = get_model("custom")

model = train(model, train_set, valid_set, batch_size=BATCH_SIZE,
              num_epochs=NUM_EPOCHS, lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

## Predict

In [None]:
# model.load_state_dict(torch.load("model_20240110024206.pth"))

def predict(model, test_set, device='auto', batch_size=64):
    """Predict a class id for every sample of ``test_set``.

    Args:
        model: trained network producing one score per class.
        test_set: dataset yielding ``(image, label)``; the label is ignored.
        device: forwarded to ``get_device``.
        batch_size: number of images per forward pass. The model runs in eval
            mode, so this is intended to affect only speed and memory use.

    Returns:
        A list of Python ``int`` class ids, in dataset order.
    """
    device = get_device(device)

    print("Predicting via %s..." % device.type.upper())

    test_iter = DataLoader(test_set, batch_size=batch_size, num_workers=5, shuffle=False)

    model = model.to(device)

    model.eval()

    pred_list = []

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        for X, _ in test_iter:
            y_hat = model(X.to(device))
            pred_list.extend(y_hat.argmax(dim=1).cpu().tolist())

    print("Prediction done")

    return pred_list

# Integrate prediction into dataset.img_labels (as int)
def integrate2dataset(dataset: LeafDataset, pred_list: list[int]):
    """Replace the dataset's labels with the predicted class ids.

    The dataset is modified in place (``pred_list`` is stored by reference,
    not copied) and also returned for convenience.
    """
    dataset.img_labels = pred_list

    return dataset

# Convert type of dataset.img_labels to str
def get_str_labels(dataset: LeafDataset):
    """Return the dataset's labels as a new list of class names.

    Integer labels (Python or NumPy integers) are translated through
    ``num2class``; every other label is passed through unchanged. Each label
    is checked individually, so an empty dataset yields an empty list
    instead of failing on ``img_labels[0]``.
    """
    return [
        num2class[label] if isinstance(label, (int, np.integer)) else label
        for label in dataset.img_labels
    ]

def envelopeDataFrame(dataset: LeafDataset):
    """Pack the dataset's image paths and string labels into a two-column DataFrame.

    Returns:
        A DataFrame with the columns ``image`` and ``label``.
    """
    columns = {
        "image": dataset.img_paths,
        "label": get_str_labels(dataset),
    }
    return pd.DataFrame(columns)

# Predict the test set, store the integer predictions on the dataset, convert them to
# class names and write the result to submission.csv in OUTPUT_PATH.
pred_list = predict(model, test_set)
integrate2dataset(test_set, pred_list)
submission = envelopeDataFrame(test_set)
submission.to_csv(OUTPUT_PATH + "submission.csv", index=False)

# Bare expression so the notebook cell displays the DataFrame.
submission