In [None]:
import os
from pathlib import Path
from zipfile import ZipFile

import matplotlib.pyplot as plt 
import torch 
import torchvision
from PIL import Image
from torch import nn 
from torchvision import transforms
from torchvision.datasets import ImageFolder

try:
  from torchinfo import summary 
except ModuleNotFoundError as e:
  print(f"{e}, downloading..")
  !pip install torchinfo
  from torchinfo import summary 

In [None]:
from zipfile import ZipFile
with ZipFile("./Metal_Surface_Defects_Dataset.zip", "r") as z:
  z.extractall(path="./Metal_Surface_Defects")


In [None]:
from torchvision.datasets import ImageFolder

train_path = Path("/content/Metal_Surface_Defects/Metal_Surface_Defects_Dataset/NEU Metal Surface Defects Data/train")
test_path = Path("/content/Metal_Surface_Defects/Metal_Surface_Defects_Dataset/NEU Metal Surface Defects Data/test")

train_path, test_path

In [None]:
from modules import data_setup, model
BATCH_SIZE = 128
NUM_WORKERS = os.cpu_count()
train_dataloader, test_dataloader, class_names = data_setup.create_dataloaders(train_path=train_path,
                                                                        test_path=test_path,
                                                                        train_transform=None,
                                                                        test_transform=None,
                                                                        batch_size=BATCH_SIZE,
                                                                        num_workers=NUM_WORKERS)
metal_defects = model.create_model(in_channels=3,
                                   out_channels=len(class_names),
                                   hidden_features=16,
                                   device="cuda")

In [None]:
loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1)
optimizer = torch.optim.Adam(metal_defects.parameters(), lr=1e-3)
EPOCHS = 20

In [None]:
def train_step(model: nn.Module,
               dataloader: torch.utils.data.DataLoader,
               loss_fn: nn.Module,
               optimizer: torch.optim.Optimizer,
               device) -> tuple:
  """Run one optimisation pass over `dataloader`.

  Args:
    model: network to train (switched to train mode).
    dataloader: yields (inputs, integer class labels) batches.
    loss_fn: criterion called as loss_fn(logits, labels).
    optimizer: optimizer over `model`'s parameters.
    device: torch.device or device string the batches are moved to.

  Returns:
    (mean loss, accuracy) over every sample seen. Both are weighted by batch size,
    so a smaller final batch is not over-counted.

  Raises:
    ValueError: if `dataloader` yields no samples.
  """
  model.train()
  total_loss, total_correct, total_samples = 0.0, 0, 0
  for X, y in dataloader:
    X, y = X.to(device), y.to(device)
    logits = model(X)
    loss = loss_fn(logits, y)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    batch_size = y.size(0)
    # assumes loss_fn returns a batch mean (reduction="mean"); rescale to a per-batch sum
    total_loss += loss.item() * batch_size
    # argmax over logits equals argmax over softmax(logits), so softmax is unnecessary
    total_correct += (logits.argmax(dim=1) == y).sum().item()
    total_samples += batch_size

  if total_samples == 0:
    raise ValueError("train dataloader yielded no samples")
  return total_loss / total_samples, total_correct / total_samples


def test_step(model: nn.Module,
              dataloader: torch.utils.data.DataLoader,
              loss_fn: nn.Module,
              device) -> tuple:
  """Evaluate `model` on `dataloader` without tracking gradients.

  Returns:
    (mean loss, accuracy) over every sample, weighted by batch size.

  Raises:
    ValueError: if `dataloader` yields no samples.
  """
  model.eval()
  total_loss, total_correct, total_samples = 0.0, 0, 0
  with torch.inference_mode():
    for X, y in dataloader:
      X, y = X.to(device), y.to(device)
      logits = model(X)
      batch_size = y.size(0)
      # accumulate into a separate running total; never overwrite it with the batch value
      total_loss += loss_fn(logits, y).item() * batch_size
      total_correct += (logits.argmax(dim=1) == y).sum().item()
      total_samples += batch_size

  if total_samples == 0:
    raise ValueError("test dataloader yielded no samples")
  return total_loss / total_samples, total_correct / total_samples


# Move batches to wherever the model's parameters live, so this cell does not depend on
# a separately defined `device` variable.
device = next(metal_defects.parameters()).device

results = {"train_loss": [],
           "train_acc": [],
           "test_loss": [],
           "test_acc": []}
for epoch in range(EPOCHS):
  train_loss, train_acc = train_step(metal_defects, train_dataloader, loss_fn, optimizer, device)
  test_loss, test_acc = test_step(metal_defects, test_dataloader, loss_fn, device)
  print(f"Epoch: {epoch+1} | Train Loss: {train_loss:.3f} | Test Loss: {test_loss:.3f} | Train Acc: {train_acc:.3f} | Test Acc: {test_acc:.3f}")
  results["train_loss"].append(train_loss)
  results["train_acc"].append(train_acc)
  results["test_loss"].append(test_loss)
  results["test_acc"].append(test_acc)
