In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader, random_split
import matplotlib.pyplot as plt
import numpy as np
import os
from torchvision import models

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cuda


In [3]:
DATA_DIR = "./jellyfish_images"


In [4]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),  
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 
])

In [5]:
full_dataset = datasets.ImageFolder(root=DATA_DIR, transform=transform)


In [6]:
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

In [7]:
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

In [8]:
class_names = full_dataset.classes
print("Classes:", class_names)

Classes: ['Moon_jellyfish', 'barrel_jellyfish', 'blue_jellyfish', 'compass_jellyfish', 'lions_mane_jellyfish', 'mauve_stinger_jellyfish']


In [9]:
model = models.resnet18(pretrained=True)
num_features = model.fc.in_features  
model.fc = nn.Linear(num_features, len(class_names)) 
model = model.to(device)



In [10]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.fc.parameters(), lr=0.001)

In [11]:
num_epochs = 10
train_loss_values = []
val_loss_values = []

In [12]:
def train_model():
    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        train_loss /= len(train_loader)
        train_loss_values.append(train_loss)
        
        model.eval()
        val_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        val_loss /= len(val_loader)
        val_loss_values.append(val_loss)
        val_acc = 100 * correct / total
        
        print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")
    
    torch.save(model.state_dict(), "jellyfish_classifier_resnet18.pth")

In [13]:
train_model()

Epoch 1/10, Train Loss: 1.7997, Val Loss: 1.4408, Val Acc: 51.11%
Epoch 2/10, Train Loss: 1.2657, Val Loss: 1.1647, Val Acc: 59.44%
Epoch 3/10, Train Loss: 1.0228, Val Loss: 1.0347, Val Acc: 68.89%
Epoch 4/10, Train Loss: 0.9010, Val Loss: 0.8369, Val Acc: 74.44%
Epoch 5/10, Train Loss: 0.7680, Val Loss: 0.7900, Val Acc: 75.00%
Epoch 6/10, Train Loss: 0.7032, Val Loss: 0.7349, Val Acc: 78.33%
Epoch 7/10, Train Loss: 0.6364, Val Loss: 0.7168, Val Acc: 77.78%
Epoch 8/10, Train Loss: 0.5932, Val Loss: 0.6632, Val Acc: 77.78%
Epoch 9/10, Train Loss: 0.5447, Val Loss: 0.6342, Val Acc: 80.00%
Epoch 10/10, Train Loss: 0.5193, Val Loss: 0.6231, Val Acc: 81.11%


In [14]:
def load_model():
    model = models.resnet18()
    model.fc = nn.Linear(num_features, len(class_names))
    model.load_state_dict(torch.load("jellyfish_classifier_resnet18.pth"))
    model.to(device)
    model.eval()
    return model

In [15]:
from PIL import Image

In [16]:
def predict_image(image_path, model):
    image = Image.open(image_path).convert("RGB")  
    image = transform(image)  
    image = image.unsqueeze(0).to(device) 

    with torch.no_grad():
        output = model(image)
        _, predicted = torch.max(output, 1)
    
    return class_names[predicted.item()]

In [17]:
sample_image = "./jellyfish_images/lions_mane_jellyfish/29.jpg"
prediction = predict_image(sample_image, model)
print("Predicted Class:", prediction)

Predicted Class: lions_mane_jellyfish
