In [None]:
# All imports for the notebook. Order is deliberate: the fastai star import below
# may rebind names imported above it (for example `DataLoader`), so later imports
# win -- do not reorder without checking which object each name ends up bound to.
import random
import warnings

import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from PIL import Image

import timm
import timm.optim
import timm.scheduler
from timm.data import ImageDataset, create_dataset, create_loader
from timm.data.transforms_factory import create_transform
from tqdm.notebook import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as T
import torchvision.datasets as datasets
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from torchinfo import summary
from fastai.vision.all import *
from fastai.vision.data import ImageDataLoaders
from fastai.metrics import accuracy, F1Score
# `warnings` is imported explicitly above rather than relying on the star import.
# This silences every warning for the whole session; narrow it with `category=`
# if a specific warning turns out to matter.
warnings.filterwarnings('ignore')
from sklearn.model_selection import train_test_split

In [None]:
# Configuration and raw label table.
SEED = 42          # shared seed for Python / NumPy / PyTorch RNGs
image_size = 224   # square input size; matches the `_224` ViT variant used below
batch = 32         # DataLoader batch size

# Seed the global RNGs before any DataLoader or model is built, so that shuffling,
# random augmentations and new-head initialisation repeat across Restart & Run All
# (GPU kernels may still add some nondeterminism).
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

train_df = pd.read_csv('/kaggle/input/hackathon-online-cloud-recognition/train.csv')
# ImageDataLoaders.from_df is called with its defaults fn_col=0 / label_col=1,
# so the filename and label columns must be the first two, in this order.
assert list(train_df.columns[:2]) == ['id', 'label'], (
    f"expected columns ['id', 'label', ...], got {list(train_df.columns)}"
)

In [None]:
# Stratified 80/20 hold-out split over image ids. In this notebook only the index
# of `y_valid` is consumed (to build the `is_valid` column); the other outputs are
# kept for inspection.
# NOTE: `stratify` is intentionally a one-column DataFrame, exactly as before --
# sklearn handles 2-D `y` differently from a Series, so switching forms could
# change which rows land in validation.
X_train, X_valid, y_train, y_valid = train_test_split(
    train_df['id'],
    train_df['label'],
    test_size=0.2,
    random_state=42,
    stratify=train_df[['label']],
)

In [None]:
# Mark the rows chosen for validation with a boolean column (consumed as valid_col).
train_df = train_df.assign(is_valid=train_df.index.isin(y_valid.index))
train_df

In [None]:
# Train/valid DataLoaders straight from the dataframe: filenames from column 0
# ('id'), labels from column 1 ('label'), split taken from the boolean 'is_valid'.
dls = ImageDataLoaders.from_df(
    train_df,
    path='/kaggle/input/hackathon-online-cloud-recognition/images/train',
    valid_col='is_valid',
    bs=batch,
    # fastai already adds ToTensor as a default item transform; only resize here.
    item_tfms=[Resize(image_size)],
    batch_tfms=aug_transforms(
        # fastai only honours `flip_vert` when `do_flip=True` (it then applies
        # Dihedral flips). With do_flip=False the requested vertical flip was silently
        # dropped. Because arbitrary rotations are already applied (max_rotate=360),
        # the extra horizontal flips that Dihedral adds introduce no new orientation.
        do_flip=True,
        flip_vert=True,
        max_rotate=360,
        p_affine=0.8,
        max_warp=0.2,
    ),
    # No explicit Normalize: when the pipeline has none, vision_learner appends one
    # built from the timm model's pretrained config. Forcing ImageNet stats here
    # blocked that. NOTE(review): timm's ViT configs use mean=std=0.5 -- confirm with
    # the model's `pretrained_cfg` if the architecture changes.
    seed=123,
)

dls.train.show_batch(max_n=30)

In [None]:
# Pretrained timm checkpoints for the ViT-B/16 @ 224 architecture (shows the available tags).
timm.list_models(filter='*vit_base_patch16_224*', pretrained=True)

In [None]:
# timm architecture name plus pretrained-weights tag, passed to vision_learner.
model_name = "vit_base_patch16_224.orig_in21k"

# Checkpoint the weights with the lowest validation loss during each fit;
# these callbacks are passed to fine_tune below.
save_cb = SaveModelCallback(monitor='valid_loss')
callbacks = [save_cb]

In [None]:
# fastai Learner wrapping the timm ViT; outputs (including checkpoints) go under path.
learn = vision_learner(
    dls,
    model_name,
    path='/kaggle/working/vit/',
    cbs=[ShowGraphCallback()],   # live loss plot while training
    metrics=[accuracy],
)
# Switch to mixed-precision training; the returned Learner is the cell's output.
learn.to_fp16()

In [None]:
# Transfer-learn for 10 epochs, checkpointing via the callbacks defined above.
learn.fine_tune(epochs=10, cbs=callbacks)

In [None]:
# Report CUDA caching-allocator statistics after training. memory_summary needs a
# CUDA device, so guard it to keep Restart & Run All from failing on a CPU-only kernel.
if torch.cuda.is_available():
    gpu_memory_info = torch.cuda.memory_summary(device=None, abbreviated=False)
else:
    gpu_memory_info = 'CUDA is not available: no GPU memory statistics to report.'
# print() rather than rich display: the summary is a preformatted text table.
print(gpu_memory_info)

In [None]:
# Pickle the trained Learner for later inference (load it back only from a trusted
# location: loading a pickle can execute arbitrary code).
learn.export('/kaggle/working/vit.pkl')

# validate() returns the validation loss followed by the metrics. Previously its
# result was thrown away because it was not the cell's last expression, so print it.
valid_metrics = learn.validate()
print(f'Validation loss and metrics: {valid_metrics}')
learn.show_results()

In [None]:
# Gather predictions and losses on the validation set (ds_idx=1) and plot the confusion matrix.
interp = ClassificationInterpretation.from_learner(learn, ds_idx=1)
interp.plot_confusion_matrix()

In [None]:
# Print fastai's per-class classification report for the predictions held by `interp`.
interp.print_classification_report()

In [None]:
# Show the 10 samples with the largest loss, to inspect the worst mistakes.
interp.plot_top_losses(k=10)