In [None]:
import torch
import torchvision
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms as T
from torchvision import io
import numpy as np
import matplotlib.pyplot as plt
import torchvision
from torchvision.models import resnet50, ResNet50_Weights


In [None]:
# Preprocessing shared by the training and validation sets: every image is
# resized to 224x224 and converted to a tensor.
# NOTE(review): no mean/std normalization is applied here. The pretrained
# weights ship their own preprocessing (`ResNet50_Weights.DEFAULT.transforms()`);
# consider using it, but only if inference preprocessing is changed to match.
trnsfrms = T.Compose(
    [
        T.Resize((224, 224)),
        T.ToTensor()
    ]
)

# Forward slashes replace the original backslashes: '\s' is an invalid escape
# sequence (a SyntaxWarning on recent Python versions) and a backslash
# separator only resolves on Windows, whereas '/' works on every platform.
train_dataset = torchvision.datasets.ImageFolder(
    'seg_train/seg_train',
    transform=trnsfrms
)
valid_dataset = torchvision.datasets.ImageFolder(
    'seg_test/seg_test',
    transform=trnsfrms
)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
# Evaluation gains nothing from shuffling; a fixed order keeps the batch
# composition (and therefore the per-batch averaged metrics) reproducible.
valid_loader = DataLoader(valid_dataset, batch_size=64, shuffle=False)

In [None]:
# Start from the default pretrained ResNet-50 weights.
model = resnet50(weights=ResNet50_Weights.DEFAULT)
# Replace the classification head. Both sizes are derived instead of being
# hard-coded (previously 2048 -> 6): the input width comes from the existing
# head and the output width from the number of classes found in the dataset.
model.fc = nn.Linear(model.fc.in_features, len(train_dataset.classes))
# Fall back to the CPU when no CUDA device is available so the notebook
# still runs on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Inverse of ImageFolder's class_to_idx mapping: class index -> class name.
idx2class = {index: name for name, index in train_dataset.class_to_idx.items()}


def compute_batch_accuracy(preds, labels):
    """Return the fraction of samples whose top-scoring class equals the label.

    Args:
        preds: Per-class scores, one row per sample; the predicted class is
            the index of the largest value along dimension 1.
        labels: Ground-truth class indices, one per sample.

    Returns:
        Number of matching predictions divided by ``len(labels)``, as a float.
    """
    hits = preds.argmax(dim=1).eq(labels)
    return hits.sum().item() / len(labels)

# Freeze every parameter of the network, then re-enable gradients for the
# weight and bias of the `fc` layer only, so that only this layer is trained.
model.requires_grad_(False)
model.fc.weight.requires_grad_(True)
model.fc.bias.requires_grad_(True)

# Move the model to the selected device (Module.to returns the module itself).
model = model.to(device)

# Loss and optimizer used by the training loop.
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

In [None]:
def fit(model: torch.nn.Module, n_epochs: int, optimizer: torch.optim.Optimizer) -> tuple:
    """Train `model` for `n_epochs` epochs, evaluating it after every epoch.

    The function reads the module-level `train_loader`, `valid_loader`,
    `criterion` and `device`, and uses `compute_batch_accuracy` to score
    each batch. One summary line is printed per epoch.

    Args:
        model: Network to optimise; it is expected to already live on `device`.
        n_epochs: Number of passes over `train_loader`.
        optimizer: Optimizer that updates the model's trainable parameters.

    Returns:
        A tuple of four lists with one entry per epoch, in this order:
        (train accuracy, train loss, validation loss, validation accuracy).
        Every entry is the mean of the per-batch values of that epoch.
    """
    train_epoch_acc = []
    train_epoch_losses = []
    valid_epoch_losses = []
    valid_epoch_acc = []
    for epoch in range(n_epochs):
        # ---- training pass ----
        loss_batch = []
        acc_batch = []
        model.train()
        for images, labels in train_loader:
            images = images.to(device)
            # The loss expects integer (long) class indices.
            labels = labels.to(device=device, dtype=torch.long)
            preds = model(images)
            loss = criterion(preds, labels)
            loss_batch.append(loss.item())
            acc_batch.append(compute_batch_accuracy(preds, labels))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        train_epoch_losses.append(np.mean(loss_batch))
        train_epoch_acc.append(np.mean(acc_batch))

        # ---- validation pass ----
        loss_batch = []
        acc_batch = []
        model.eval()
        # No gradients are needed for evaluation; disabling autograd avoids
        # building the graph and saves memory.
        with torch.no_grad():
            for images, labels in valid_loader:
                images = images.to(device)
                labels = labels.to(device=device, dtype=torch.long)
                preds = model(images)
                # Compute the loss on the validation batch itself. Previously
                # the stale `loss` of the last training batch was logged here.
                loss = criterion(preds, labels)
                loss_batch.append(loss.item())
                acc_batch.append(compute_batch_accuracy(preds, labels))
        valid_epoch_losses.append(np.mean(loss_batch))
        valid_epoch_acc.append(np.mean(acc_batch))
        print(f'[Epoch {epoch:02d}] Train loss: {train_epoch_losses[-1]:.4f}, valid loss = {valid_epoch_losses[-1]:.4f} Train acc {train_epoch_acc[-1]:.4f} Valid acc {valid_epoch_acc[-1]:.4f}')
    return train_epoch_acc, train_epoch_losses, valid_epoch_losses, valid_epoch_acc


# Train for 10 epochs and keep the per-epoch metric histories returned by fit().
(
    train_epoch_acc,
    train_epoch_losses,
    valid_epoch_losses,
    valid_epoch_acc,
) = fit(model=model, n_epochs=10, optimizer=optimizer)


In [None]:
def get_prediction(path: str) -> str:
    """Classify a single image file and return the predicted class name.

    Uses the module-level `model`, `device` and `idx2class`.

    Args:
        path: Path of the image file to classify.

    Returns:
        The class name that `idx2class` maps the highest-scoring index to.
    """
    # Force a 3-channel RGB tensor so grayscale or RGBA files do not break
    # the forward pass of a model that takes 3-channel input.
    img = io.read_image(path, mode=io.ImageReadMode.RGB)
    # Scale the uint8 pixel values to [0, 1] before resizing to 224x224.
    img = T.Resize((224, 224))(img / 255)
    model.eval()
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(img.unsqueeze(0).to(device))
    # Softmax is monotonic, so the argmax of the raw scores is the same class.
    predicted_class_index = torch.argmax(logits, dim=1)
    return idx2class[predicted_class_index.item()]

# Forward slash instead of an escaped backslash: the path then resolves on
# every platform, not only on Windows.
get_prediction('seg_pred/144.jpg')


In [None]:
# Persist only the model's state dict (not the whole module object) to disk.
torch.save(obj=model.state_dict(), f='savemodel.pt')