In [None]:
import torch
import torchvision
import os
import matplotlib.pyplot as plt
import pandas as pd
import torch._dynamo
# Per the TorchDynamo troubleshooting docs, this makes compilation failures
# fall back to eager execution instead of raising.
torch._dynamo.config.suppress_errors = True
from torch import nn
from torchvision import transforms
try:
  from torchinfo import summary
except ModuleNotFoundError as e:
  print(f"{e}, Downloading..")
  !pip install torchinfo
  from torchinfo import summary

In [None]:
from modules import data_setup, model, train

# create_model hands back the network together with the train and test
# transforms. NOTE: rebinding `model` here shadows the imported
# `modules.model` module for the rest of the notebook.
# 10000 is presumably the number of output classes -- TODO confirm.
model, train_transform, test_transform = model.create_model(10000)

# Build the loaders with the transforms that were returned alongside the model.
(train_dataloader,
 test_dataloader,
 train_data) = data_setup.create_dataloaders(
    train_transforms=train_transform,
    test_transforms=test_transform,
)

In [None]:
# Show a torchinfo table for the model using a single dummy input; as the
# last expression of the cell, the result is rendered as the cell output.
summary(
    model=model,
    input_size=(1, 3, 224, 224),  # (batch_size, color_channels, height, width)
    col_names=["input_size", "output_size", "num_params", "trainable"],
    col_width=20,
    row_settings=["var_names"],
)

In [None]:
# Build a lookup table from the training class directory names, which are
# split on "_" into the fields listed in `cols`.
cols = ["label", "kingdom", "phylum", "class", "order", "family", "genus", "name"]

# Bound the split to len(cols) fields: any extra underscore stays inside the
# last field ("name") instead of producing a ninth column, which would make
# the DataFrame constructor raise.
data = [i.split("_", len(cols) - 1) for i in os.listdir("./data/train/2021_train_mini")]

# os.listdir returns entries in arbitrary order, so sort by label before
# indexing. Later cells use df.iloc[k] positionally -- presumably k is the
# class index assigned by the dataloader; verify against data_setup.
# Chained (non-inplace) calls keep this cell idempotent on re-run.
df = (
    pd.DataFrame(data=data, columns=cols)
    .sort_values("label")
    .set_index("label", drop=True)
)

In [None]:
# Training configuration: device, optimizer, loss, compiled model, epochs.

# Move the model to its final device first; the torch.optim docs advise
# doing this before constructing an optimizer over its parameters.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = torch.nn.CrossEntropyLoss(label_smoothing=0.1)

# torch.compile returns the optimized module and leaves its argument
# untouched, so the result must be rebound -- discarding it (as before)
# means the compiled model is never used. The wrapper shares parameters with
# the original module, so the optimizer above stays valid.
# NOTE: re-running this cell wraps the already-compiled model again.
model = torch.compile(model)
EPOCHS = 50

In [None]:
# Run the training loop from the project's `train` module and keep whatever
# it returns for later inspection.
results = train.train_model(
    model, train_dataloader, test_dataloader, loss_fn, optimizer, EPOCHS, device
)

In [None]:
# Grab the first batch yielded by a fresh iterator over the test dataloader;
# the prediction cell below indexes val[0] for the images.
val = next(iter(test_dataloader))


In [None]:
import random

# Pick a random position inside the batch that was actually fetched; a
# hard-coded upper bound of 127 raises IndexError on any smaller batch.
random_img = random.randrange(len(val[0]))
model.eval()
with torch.inference_mode():
  img = val[0][random_img]
  # Add a batch dimension, move to the model's device and run the forward pass.
  img_converted = model(img.unsqueeze(dim=0).to(device))
  # argmax over the raw outputs picks the same class as argmax over their
  # softmax (softmax is monotonic), so the softmax is skipped.
  pred_label = torch.argmax(img_converted, dim=1)
  # The true class comes from the batch's targets, not from the sample's
  # position in the batch. Assumes val is an (images, labels) pair with
  # integer class indices in val[1] -- TODO confirm against data_setup.
  true_label = int(val[1][random_img])
  plt.imshow(img.cpu().permute(1, 2, 0));
  # .item() / int() hand pandas plain Python ints for positional lookup.
  plt.title(f"Prediction Label: {df.iloc[pred_label.item()]} | True Label: {df.iloc[true_label]}")
  plt.axis("off");