In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader, random_split
import matplotlib.pyplot as plt
import numpy as np
import os

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print ("using device: ",device)

using device:  cuda


In [3]:
DATA_DIR = "./JellyFish"

In [34]:
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])


In [35]:
full_dataset = datasets.ImageFolder(root=DATA_DIR, transform=transform)

In [36]:
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

In [37]:
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

In [38]:
class_names = full_dataset.classes
print("Classes:", class_names)

Classes: ['Moon_jellyfish', 'barrel_jellyfish', 'blue_jellyfish', 'compass_jellyfish', 'lions_mane_jellyfish', 'mauve_stinger_jellyfish']


In [39]:
class CNNModel(nn.Module) : 
    def __init__(self, num_classes):
        super(CNNModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
        self.relu1 = nn.ReLU()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.relu2 = nn.ReLU()
        self.fc1 = nn.Linear(64 * 32 * 32, 128)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        x = self.pool(self.relu1(self.conv1(x)))
        x = self.pool(self.relu2(self.conv2(x)))
        x = x.view(x.size(0), -1)
        x = self.relu3(self.fc1(x))
        x = self.fc2(x)
        return x

In [12]:
num_classes=len(class_names)
model = CNNModel(num_classes).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr = 0.001)

In [21]:
num_epochs = 10
train_loss_values = []
val_loss_values = []

In [24]:
def train_model():
    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        train_loss /= len(train_loader)
        train_loss_values.append(train_loss)
        
        model.eval()
        val_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        val_loss /= len(val_loader)
        val_loss_values.append(val_loss)
        val_acc = 100 * correct / total
        
        print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")
    
    torch.save(model.state_dict(), "jellyfish_classifier.pth")

In [25]:
train_model()

Epoch 1/10, Train Loss: 0.0002, Val Loss: 1.6920, Val Acc: 71.11%
Epoch 2/10, Train Loss: 0.0002, Val Loss: 1.7006, Val Acc: 71.11%
Epoch 3/10, Train Loss: 0.0002, Val Loss: 1.7059, Val Acc: 71.11%
Epoch 4/10, Train Loss: 0.0002, Val Loss: 1.7107, Val Acc: 71.11%
Epoch 5/10, Train Loss: 0.0002, Val Loss: 1.7191, Val Acc: 71.11%
Epoch 6/10, Train Loss: 0.0002, Val Loss: 1.7221, Val Acc: 71.11%
Epoch 7/10, Train Loss: 0.0002, Val Loss: 1.7331, Val Acc: 71.11%
Epoch 8/10, Train Loss: 0.0002, Val Loss: 1.7355, Val Acc: 71.11%
Epoch 9/10, Train Loss: 0.0001, Val Loss: 1.7415, Val Acc: 71.11%
Epoch 10/10, Train Loss: 0.0001, Val Loss: 1.7427, Val Acc: 71.11%


In [28]:
def load_model():
    model = CNNModel(num_classes)
    model.load_state_dict(torch.load("jellyfish_classifier.pth"))
    model.to(device)
    model.eval()
    return model

model = load_model()

  model.load_state_dict(torch.load("jellyfish_classifier.pth"))


In [29]:
from PIL import Image

In [30]:
def predict_image(image_path, model):
    image = Image.open(image_path)
    image = transform(image).convert("RGB")
    image = image.unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(image)
        _, predicted = torch.max(output, 1)
    return class_names[predicted.item()]


In [33]:
sample_image = "./JellyFish/lions_mane_jellyfish/29.jpg"
prediction = predict_image(sample_image, model)
print("Predicted Class:", prediction)


Predicted Class: compass_jellyfish
