In [None]:
import os
import random
import shutil
from pathlib import Path

import matplotlib.pyplot as plt
import torch
import torchvision
from PIL import Image
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms

import torch.multiprocessing as mp

# Use the "spawn" start method for worker processes.
# force=True keeps this cell re-runnable: without it, a second execution in the
# same kernel raises "RuntimeError: context has already been set".
mp.set_start_method('spawn', force=True)

In [None]:
# Start from an empty ./data directory: wipe any existing one, then recreate it.
data_path = Path("./data")
if data_path.is_dir():
  # rmtree removes the whole tree regardless of nesting depth or hidden files;
  # the earlier shell glob only reached three levels and needed a POSIX `rm`.
  shutil.rmtree(data_path)
data_path.mkdir(parents=True, exist_ok=True)

In [None]:
# Make sure the ./modules directory exists. exist_ok=True keeps the cell
# re-runnable (a plain `mkdir` complains when the directory is already there)
# and pathlib works on every platform.
Path("modules").mkdir(exist_ok=True)

In [None]:
# Select the compute device: CUDA when PyTorch can see a GPU, otherwise the CPU.
if torch.cuda.is_available():
  device = "cuda"
else:
  device = "cpu"

In [None]:
from modules import model, data

# Build the network together with its train/test transforms.
# NOTE(review): 102 is presumably the number of output classes - confirm
# against modules/model.py.
effnetb2_v2_m, train_transforms, test_transforms = model.create_effnetb2_v2_m(102)

# Collect the loader settings in one place, then create the three data loaders.
dataloader_kwargs = dict(
  root=data_path,
  train_transforms=train_transforms,
  test_transforms=test_transforms,
  batch_size=128,
  device="cpu",
)
train_dataloader, test_dataloader, val_dataloader = data.create_dataloaders(**dataloader_kwargs)


In [None]:

# Preview a few randomly chosen raw images from the dataset folder.
NUM_PREVIEW_IMAGES = 5  # used for both the sample size and the number of subplot rows

img_path = Path("./data/flowers-102/jpg/")
# os.listdir already returns a list, so it can be sampled directly.
random_img_paths = random.sample(os.listdir(img_path), k=NUM_PREVIEW_IMAGES)
fig, ax = plt.subplots(nrows=NUM_PREVIEW_IMAGES, ncols=1, figsize=(18,12))
for idx, pth in enumerate(random_img_paths):
  # The context manager closes the underlying file as soon as the pixels have
  # been handed to matplotlib, instead of leaving the handle open.
  with Image.open(img_path/pth) as img:
    ax[idx].imshow(img)
  ax[idx].axis("off")

In [None]:
# Move the model to the device selected earlier in the notebook.
effnetb2_v2_m.to(device=device)

In [None]:
# Training setup: epoch count, loss, optimiser and a compiled wrapper of the model.
EPOCHS = 30
loss_fn = torch.nn.CrossEntropyLoss(label_smoothing=0.1)
optimizer = torch.optim.Adam(params=effnetb2_v2_m.parameters(), lr=1e-3)
# The compiled wrapper is what gets trained; effnetb2_v2_m stays available under
# its own name for the later save and inference cells.
compile_model = torch.compile(model=effnetb2_v2_m)

In [None]:
import torch._dynamo

# Switch on TorchDynamo's suppress_errors flag through a short alias.
dynamo_config = torch._dynamo.config
dynamo_config.suppress_errors = True

In [None]:
from modules import train

# Run the training loop on the compiled model and keep whatever train_model
# returns for the plotting cell below.
results = train.train_model(
  model=compile_model,
  train_dataloader=train_dataloader,
  test_dataloader=test_dataloader,
  optimizer=optimizer,
  loss_fn=loss_fn,
  device=device,
  epochs=EPOCHS,
)

In [None]:
import pandas as pd

# Tabulate the training results and draw every column as a line over the row index.
results_df = pd.DataFrame(data=results)
results_df.plot.line();

In [None]:
# Persist the weights of the original (uncompiled) module object to disk.
weights_file = "./flower102_effnetb2_v2_m.pth"
model_state = effnetb2_v2_m.state_dict()
torch.save(model_state, weights_file)

In [None]:
# Pull the first batch out of the validation loader for the inference demo below.
val_batches = iter(val_dataloader)
val = next(val_batches)

In [None]:
!wget "https://gist.githubusercontent.com/JosephKJ/94c7728ed1a8e0cd87fe6a029769cde1/raw/403325f5110cb0f3099734c5edb9f457539c77e9/Oxford-102_Flower_dataset_labels.txt"

In [None]:
# Read the class names, one per line. The list position is used as the class
# index when the prediction is plotted further down.
with open("Oxford-102_Flower_dataset_labels.txt", "r", encoding="utf-8") as f:
  # Strip surrounding whitespace (including the trailing newline, which would
  # otherwise end up inside the plot title) and skip blank lines.
  class_names = [line.strip() for line in f if line.strip()]

In [None]:
# Show the model's prediction for one randomly chosen image of the validation batch.
val_images, val_labels = val[0], val[1]
# Draw the index from the actual batch length rather than a hard-coded 0..127,
# so a batch with fewer than 128 samples cannot cause an IndexError.
random_img = random.randrange(len(val_images))
effnetb2_v2_m.eval()
with torch.inference_mode():
  img = val_images[random_img]
  # Add a leading batch dimension and move the sample to the model's device.
  img_converted = effnetb2_v2_m(img.unsqueeze(dim=0).to(device))
  # softmax is monotonic, so argmax over the raw outputs picks the same class.
  pred_label = torch.argmax(img_converted, dim=1)

# Plain Python ints make the list lookups explicit.
pred_idx = int(pred_label)
true_idx = int(val_labels[random_img])

# NOTE(review): if test_transforms normalise the tensor, imshow will clip the
# values - confirm against modules/model.py.
plt.imshow(img.cpu().permute(1, 2, 0));
# strip() keeps stray newlines from the label file out of the title.
plt.title(f"Prediction Label: {class_names[pred_idx].strip()} | Label: {class_names[true_idx].strip()}")
plt.axis("off");