In [None]:
import numpy as np
import pandas as pd

In [None]:
import torch
from datasets import (Dataset,
                      concatenate_datasets,
                      Image,
                      ClassLabel,)

from PIL import Image as pil

from transformers import (Swinv2Config,
                          AutoImageProcessor,
                          AutoModelForImageClassification,
                          Trainer,
                          TrainingArguments,)

from torchvision.transforms import (
                            Compose,
                            RandomResizedCrop,
                            CenterCrop,
                            RandomRotation,
                            ColorJitter,
                            Normalize,
                            Resize,
                            ToTensor,
                            ToPILImage,)

import evaluate

### Load and augment image dataset

In [None]:
# Run the companion notebook that holds the helper functions.
# NOTE(review): the names used below but not defined in this notebook (select_indexes,
# generate_dataset, preprocess_img, compute_metrics, img_collate_fn, softmax) presumably
# come from it — confirm.
%run CDCC_Image_Classification_Func.ipynb

In [None]:
# Path templates; the '{}' placeholder is filled in with str.format() at each use.

# Folder with the label CSV files (Train.csv and Test.csv are read from here below)
labels_path = '/Zindi_Crop_Classification/Data/Image_MetaData/{}'

# Folder with the original images (used below to open the test images)
image_path = '/Zindi_Crop_Classification/Data/Raw_Images/{}'
# Folder with augmented images; the 'other/other_' and 'ND/ND_' prefixes are passed to generate_dataset below
aug_image_path = '/Zindi_Crop_Classification/Data/Augmented_Images/{}'

# Folder for the training output directory and for the 'Best_Model' folder loaded at inference time
model_path = '/Zindi_Crop_Classification/Fine_Tuned_Models/{}'

#### Exploratory Data Analysis

In [None]:
# Load the training labels, then show the frame's shape and its first rows
mstr_frame = pd.read_csv(labels_path.format('Train.csv'))
print(mstr_frame.shape)
mstr_frame.head()

In [None]:
# Number of rows per 'damage' class — the data set is heavily unbalanced
mstr_frame['damage'].value_counts()

# Output recorded from an earlier run:
# damage
# G        11623
# WD        9238
# DR        4516
# other      419
# ND         272
# Name: count, dtype: int64

In [None]:
# Class names in the order produced by value_counts() on the 'damage' column
class_labels = list(mstr_frame['damage'].value_counts().index)
class_labels

### Data augmentation (optional)
<b>Generate additional images for the under-represented classes.</b><br>
<b>Consider geometric transformations, photometric transformations, or a combination thereof.</b><br>

In [None]:
# Seeded NumPy random generator (seed 200).
# NOTE(review): `rng` is not referenced again in the cells visible here; presumably the
# helper functions loaded via %run read it as a global — confirm.
rng = np.random.default_rng(200)

In [None]:
# Generate training dataset
# NOTE(review): select_indexes is defined outside this notebook; the meaning of the
# arguments 5000 and 0.7 is not visible here — confirm against its definition.
# Its result is used as positional row indexes into mstr_frame.
trng_idx = select_indexes(5000, 0.7, mstr_frame)
# .copy() gives an independent frame rather than a view of mstr_frame
subset_labels = mstr_frame.iloc[trng_idx, :].copy()
subset_labels.shape

In [None]:
# Generate validation dataset
# Positional indexes of the mstr_frame rows that were NOT selected for training.
# The list order is whatever set iteration yields, so it should not be assumed to be sorted.
non_selected_idx = list(set(range(mstr_frame.shape[0])) - set(trng_idx))
val_labels = mstr_frame.iloc[non_selected_idx, :].copy()

# NOTE(review): select_indexes is defined outside this notebook; the meaning of the
# arguments 2000 and 1 is not visible here — confirm against its definition.
# val_idx is applied with .iloc to val_labels, i.e. it holds positions within val_labels, not mstr_frame.
val_idx = select_indexes(2000, 1, val_labels)
subset_eval_labels = val_labels.iloc[val_idx, :].copy()
subset_eval_labels.shape

In [None]:
aug_img_size = 256 # Needed for RandomResizedCrop

# Augmentation pipeline. Only the random rotation is active; the crop and colour-jitter
# variants are kept commented out as alternatives.
# NOTE(review): aug_transform is not referenced again in the cells visible here; presumably
# generate_dataset reads it as a global when creating augmented images — confirm.
aug_transform = Compose(
                [
                    #RandomResizedCrop((aug_img_size, aug_img_size)),
                    #ColorJitter((0.5, 2)),
                    #ColorJitter(contrast=(0.5, 2)),
                    RandomRotation(30),
                    ToTensor(),
                ]
            )

# Tensor -> PIL image converter; not referenced again in the cells visible here
topilImage = ToPILImage()

In [None]:
# Build the training dataset from the selected label rows.
# NOTE(review): generate_dataset is defined outside this notebook; the meaning of its
# positional arguments is only known from the inline comments here — confirm against its definition.
trng_dataset = generate_dataset(subset_labels, 
                                0, 
                                0, # Change this value if data augmentation needs to be performed, e.g. 1500
                                aug_image_path.format('other/other_'), # path prefix for augmented 'other' images
                                aug_image_path.format('ND/ND_'), # path prefix for augmented 'ND' images
                                True,
                                'Trng',
                               )

In [None]:
# Build the validation dataset from the selected validation label rows.
# NOTE(review): generate_dataset is defined outside this notebook; the meaning of its
# positional arguments is only known from the inline comments here — confirm against its definition.
val_dataset = generate_dataset(subset_eval_labels, 
                                0, # change this value if validation set should include augmented images, e.g. 1500
                                0, # if the previous value is non-zero, then this should be previous value + number of images, e.g. 2000
                                aug_image_path.format('other/other_'), # path prefix for augmented 'other' images
                                aug_image_path.format('ND/ND_'), # path prefix for augmented 'ND' images
                                True,
                                'Val',
                              )

In [None]:
# Class names as stored in the dataset's 'label' feature; a name's position is its integer id
data_labels = trng_dataset.features['label'].names
# Two-way lookup tables between class name and class id (passed to the model config below)
label2id = {name: idx for idx, name in enumerate(data_labels)}
id2label = {idx: name for idx, name in enumerate(data_labels)}

In [None]:
# Shuffle and split the training dataset; only the 90% 'train' part is used for fine-tuning.
# The fixed seed (same value as the NumPy generator created earlier) makes the split
# reproducible across kernel restarts; without it a different 10% is dropped on every run.
# NOTE(review): the 10% 'test' part of this split is not used in the cells visible here.
splits = trng_dataset.train_test_split(test_size=0.1, seed=200)
train_ds = splits['train']

In [None]:
# Shuffle and split the validation dataset; only the 90% 'train' part is used for evaluation.
# The fixed seed keeps the evaluation set identical across kernel restarts, so accuracy
# values from different runs are comparable.
# NOTE(review): the 10% 'test' part of this split is not used in the cells visible here.
val_splits = val_dataset.train_test_split(test_size=0.1, seed=200)
val_ds = val_splits['train']

In [None]:
# Accuracy metric from the `evaluate` library.
# NOTE(review): `metric` is not referenced again in the cells visible here; presumably
# compute_metrics (defined outside this notebook) reads it as a global — confirm.
metric = evaluate.load("accuracy")

### Model training and validation
<b>Step 1: Resize all images to a fixed size.</b><br>
<b>Step 2: Initialize the model's parameters using the pretrained SwinV2 model.</b>

In [None]:
# Pretrained checkpoint that is fine-tuned below
model_name = "microsoft/swinv2-tiny-patch4-window16-256"
batch_size = 12 #4 # batch size for training and evaluation

# Image processor of the checkpoint; its image_mean / image_std are used for normalization below
image_processor = AutoImageProcessor.from_pretrained(model_name)

In [None]:
crop_size = 256  # side length in pixels of the square model input
# Normalize with the same per-channel mean / std the pretrained checkpoint's processor uses
normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)

# Preprocessing applied to every training / validation image.
img_transform = Compose(
        [
            # Resize to exactly crop_size x crop_size. Resize(crop_size) with a single int
            # only scales the SHORTER edge and keeps the aspect ratio, so non-square images
            # would stay non-square: they cannot be stacked into one batch when aspect ratios
            # differ, and they would not match the 256 x 256 resize used at inference time.
            Resize((crop_size, crop_size)),
            ToTensor(), # scales an image so that each channel has values in the range [0, 1]
            normalize,
        ]
    )

In [None]:
# Create 'pixel_values' column on-the-fly
# set_transform registers the function on the dataset object itself (in place).
# NOTE(review): preprocess_img is defined outside this notebook; presumably it applies
# img_transform to each image — confirm.
train_ds.set_transform(preprocess_img)
val_ds.set_transform(preprocess_img)

In [None]:
# Create a configuration with optional regularization.
# Start from the CHECKPOINT's configuration and override only what changes for this task.
# Building Swinv2Config(...) from scratch would use the library defaults for every
# architecture setting (e.g. the attention window size) instead of the values the
# pretrained weights were trained with, and ignore_mismatched_sizes=True in the model
# loading call would hide any resulting mismatch.
swin_config = Swinv2Config.from_pretrained(
    model_name,
    image_size=crop_size,
    # Optional regularization — uncomment to enable:
    #hidden_dropout_prob=0.1,
    #attention_probs_dropout_prob=0.1,
    label2id=label2id, # this is required to change the number of nodes in the output layer
    id2label=id2label, # this is required to change the number of nodes in the output layer
)

In [None]:
# Load the pretrained weights into a model built from swin_config
model = AutoModelForImageClassification.from_pretrained(
    model_name,
    ignore_mismatched_sizes = True, # this is required to change the number of nodes in the output layer
    config=swin_config,
)

In [None]:
# Output folder name: checkpoint name without the organisation prefix, plus '-CDCC'
model_prefix = model_name.split("/")[-1] + '-CDCC'

# Vary the learning rate and determine its impact on model's performance
# NOTE(review): newer transformers releases rename `evaluation_strategy` to `eval_strategy`;
# check the installed version before upgrading.
args = TrainingArguments(
    model_path.format(model_prefix), # output directory for checkpoints
    remove_unused_columns=False, # presumably needed so preprocess_img still receives the raw image column — confirm
    evaluation_strategy = "epoch",
    save_strategy = "epoch", # same schedule as evaluation
    learning_rate=1e-5, #5e-5
    per_device_train_batch_size=batch_size,
    gradient_accumulation_steps=4, # effective train batch per device = batch_size * 4
    per_device_eval_batch_size=batch_size,
    num_train_epochs=40,
    warmup_ratio=0.1,
    logging_steps=10,
    load_best_model_at_end=True, # best = highest value of metric_for_best_model
    metric_for_best_model="accuracy",
    push_to_hub=False,
)

In [None]:
# Assemble the Trainer from the model, the arguments and the two datasets.
# NOTE(review): compute_metrics and img_collate_fn are defined outside this notebook.
# NOTE(review): newer transformers releases deprecate the `tokenizer` argument in favour of
# `processing_class`; check the installed version before upgrading.
trainer = Trainer(
    model,
    args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    tokenizer=image_processor, # passing the image processor here lets the Trainer save it with the model
    compute_metrics=compute_metrics,
    data_collator=img_collate_fn,
)

In [None]:
# Fine-tune the model, then save it; save_model() is called without a path, so it writes
# to the Trainer's configured output directory.
# NOTE(review): the next section loads from the 'Best_Model' folder, which is filled by a
# manual copy step, so a plain "Run All" depends on that folder already existing.
train_results = trainer.train()
trainer.save_model()

### Prediction on Zindi's test data set

In [None]:
# Copy the best model obtained at the end of model training to 'Best_Model' folder
# (manual step — this cell only reads from that folder)

ck_model_name = model_path.format('Best_Model')
best_processor = AutoImageProcessor.from_pretrained(ck_model_name)
best_model = AutoModelForImageClassification.from_pretrained(ck_model_name)
# Override the processor's target size so test images are resized to 256 x 256
best_processor.size = {'height': 256, 'width': 256}

In [None]:
# Load the test-set metadata, then show the frame's shape and its first rows
test_frame = pd.read_csv(labels_path.format('Test.csv'))
print(test_frame.shape)
test_frame.head()

In [None]:
# Run the fine-tuned model on every test image and collect the raw logits.
num_test_samples = test_frame.shape[0]
# Take the number of classes from the model instead of hard-coding 5
num_classes = best_model.config.num_labels
# One row per test image, one column per class id
test_logit = np.zeros((num_test_samples, num_classes))

best_model.eval()  # make sure dropout etc. are in inference mode
# no_grad: no autograd graph is built, which saves memory and time during pure inference
with torch.no_grad():
    for i in range(num_test_samples):
        if i % 1000 == 0:
            print(f'Test image number {i}')
        # Column 1 of test_frame is used as the image file name.
        # The context manager closes the underlying file once the RGB copy exists.
        with pil.open(image_path.format(test_frame.iloc[i, 1])) as raw_img:
            test_img = raw_img.convert('RGB')
        encoding = best_processor(test_img, return_tensors="pt")
        out = best_model(**encoding)
        # logits has shape (1, num_classes) for a single image; store its only row
        test_logit[i, :] = out.logits[0].numpy(force=True)

In [None]:
# Display the class names derived from the training CSV; value recorded from an earlier run:
class_labels
# ['G', 'WD', 'DR', 'other', 'ND']

In [None]:
# Column j of test_logit is the logit of class id j, so name the columns from the model's
# own id -> label mapping. class_labels is ordered by class frequency in the training CSV,
# which is not guaranteed to match the model's class ids and would silently mislabel columns.
model_labels = [best_model.config.id2label[idx] for idx in range(best_model.config.num_labels)]
out_frame = pd.DataFrame(test_logit, columns=model_labels)
out_frame['ID'] = test_frame['ID']
# order for submission - DR,G,ND,WD,other
sorted_frame = out_frame[['ID', 'DR','G', 'ND', 'WD', 'other']].copy()
sorted_frame.to_csv('CDCC_Inference_2.csv', index=False, sep=',')

In [None]:
# Convert logits to probabilities
# NOTE(review): `softmax` is not imported in the cells visible here; presumably it is made
# available by the helper notebook loaded via %run — confirm.
pred_logits = pd.read_csv('CDCC_Inference_2.csv')
# Column 0 is 'ID'; softmax is taken across the remaining (class) columns of each row
softmax_prob = softmax(pred_logits.iloc[:, 1:], axis=1)
pred_logits.iloc[:, 1:] = softmax_prob # Overwrite the logits with probabilities
pred_logits.to_csv('CDCC_Inference_Prob_2.csv', index=False, sep=",")