In [1]:
import os
import shutil
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from PIL import Image
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import ResNetForImageClassification, ResNetConfig
from tqdm import tqdm

In [2]:
# Root of the Kaggle snake-breeds dataset; SnakeDataset expects one
# sub-directory per class name underneath it.
data_dir = os.path.join('/kaggle/input', 'identifying-different-breeds-of-snakes', 'dataset')

In [3]:
class SnakeDataset(Dataset):
    """Image-folder dataset with one sub-directory of `data_dir` per class.

    Args:
        data_dir: Root directory containing one folder per class name.
        classes: Ordered class names; a class's position in this sequence is
            its integer label.
        transform: Optional callable applied to each loaded RGB PIL image.

    Each item is an ``(image, label)`` pair: the RGB image (transformed when
    ``transform`` is given) and its integer label.
    """

    def __init__(self, data_dir, classes, transform=None):
        self.data_dir = data_dir
        self.classes = classes
        self.transform = transform
        self.image_paths = []
        self.labels = []

        for label, class_name in enumerate(classes):
            class_dir = os.path.join(data_dir, class_name)
            # os.listdir order is filesystem-dependent; sort it so the index order
            # (and therefore any seeded train/val split over indices) is reproducible.
            for img_name in sorted(os.listdir(class_dir)):
                img_path = os.path.join(class_dir, img_name)
                # Skip sub-directories and other non-file entries, which would
                # otherwise only fail later inside __getitem__.
                if not os.path.isfile(img_path):
                    continue
                self.image_paths.append(img_path)
                self.labels.append(label)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        label = self.labels[idx]
        # Close the file handle promptly: convert() returns a new, fully loaded
        # image, so the opened file is no longer needed afterwards.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        return image, label

In [4]:
# The 35 snake species, one per sub-directory of the dataset root. A species'
# position in this list is its integer label, so the order must stay fixed
# between training and inference.
classes = """
    agkistrodon-contortrix agkistrodon-piscivorus coluber-constrictor
    crotalus-atrox crotalus-horridus crotalus-ruber crotalus-scutulatus
    crotalus-viridis diadophis-punctatus haldea-striatula heterodon-platirhinos
    lampropeltis-californiae lampropeltis-triangulum masticophis-flagellum
    natrix-natrix nerodia-erythrogaster nerodia-fasciata nerodia-rhombifer
    nerodia-sipedon opheodrys-aestivus pantherophis-alleghaniensis
    pantherophis-emoryi pantherophis-guttatus pantherophis-obsoletus
    pantherophis-spiloides pantherophis-vulpinus pituophis-catenifer
    rhinocheilus-lecontei storeria-dekayi storeria-occipitomaculata
    thamnophis-elegans thamnophis-marcianus thamnophis-proximus
    thamnophis-radix thamnophis-sirtalis
""".split()

In [5]:
# Per-channel normalisation statistics (the standard ImageNet mean/std values);
# presumably these match the pretrained checkpoint's preprocessing.
NORM_MEAN = [0.485, 0.456, 0.406]
NORM_STD = [0.229, 0.224, 0.225]

# Deterministic preprocessing shared by training and validation:
# fixed 224x224 resize -> tensor -> channel-wise normalisation.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
    ]
)

In [6]:
# Reuse the `data_dir` defined above instead of repeating the path, so changing
# the dataset location only needs one edit.
dataset = SnakeDataset(data_dir=data_dir, classes=classes, transform=transform)
assert len(dataset) > 0, f"no images found under {data_dir}"

In [7]:
# 80/20 split over dataset indices. Stratify on the labels so every class keeps
# the same proportion in both splits; with this many classes a plain random split
# can leave the rarer species under-represented in validation.
train_indices, val_indices = train_test_split(
    list(range(len(dataset))),
    test_size=0.2,
    random_state=42,
    stratify=dataset.labels,
)
assert not set(train_indices) & set(val_indices), "train/val splits overlap"

train_dataset = torch.utils.data.Subset(dataset, train_indices)
val_dataset = torch.utils.data.Subset(dataset, val_indices)

In [8]:
BATCH_SIZE = 32

# Shuffle only the training batches; validation order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=BATCH_SIZE)

In [9]:
# Pretrained ResNet-50 checkpoint loaded from the Hugging Face Hub; its
# classification head is replaced in the next cell.
PRETRAINED_CHECKPOINT = 'microsoft/resnet-50'
model = ResNetForImageClassification.from_pretrained(PRETRAINED_CHECKPOINT)

config.json:   0%|          | 0.00/69.6k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/102M [00:00<?, ?B/s]

In [10]:
# Replace the checkpoint's classification head with a freshly initialised Linear
# layer sized for our snake classes. The head is expected to end in a Linear
# layer; fail loudly otherwise rather than train a mis-shaped model.
last_layer = model.classifier[-1]

if isinstance(last_layer, nn.Linear):
    model.classifier[-1] = nn.Linear(last_layer.in_features, len(classes))
    # Keep the Hugging Face config in sync with the new head. Otherwise
    # `model.config.id2label` (and anything saved via `save_pretrained`) still
    # describes the checkpoint's original label set.
    model.num_labels = len(classes)
    model.config.num_labels = len(classes)
    model.config.id2label = {i: name for i, name in enumerate(classes)}
    model.config.label2id = {name: i for i, name in enumerate(classes)}
else:
    raise ValueError("The last layer of the classifier is not a Linear layer.")

In [11]:
# The head was already replaced in the previous cell (index 1 is the last layer
# of the classifier), so re-creating it here only re-initialised it a second
# time. Verify its shape instead of rebuilding it.
assert isinstance(model.classifier[-1], nn.Linear)
assert model.classifier[-1].out_features == len(classes), (
    "classifier head size does not match the number of classes"
)

In [12]:
# Prefer the GPU when PyTorch can see one, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
model = model.to(device)

In [13]:
LEARNING_RATE = 1e-4

# Multi-class objective over the raw logits; Adam updates every parameter
# (backbone and new head alike) at a single learning rate.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

In [None]:
import matplotlib.pyplot as plt

def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over `loader` and return the mean per-sample loss."""
    model.train()
    loss_sum = 0.0
    n_samples = 0
    for inputs, labels in tqdm(loader):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs).logits
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Weight by batch size so a short final batch does not count as much as a
        # full one (assumes `criterion` averages over the batch, as the
        # nn.CrossEntropyLoss default does).
        batch_size = labels.size(0)
        loss_sum += loss.item() * batch_size
        n_samples += batch_size
    return loss_sum / n_samples if n_samples else float('nan')


def _evaluate(model, loader, criterion, device):
    """Return (mean per-sample loss, accuracy) of `model` on `loader`, without gradients."""
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_samples = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs).logits
            batch_size = labels.size(0)
            loss_sum += criterion(outputs, labels).item() * batch_size

            _, preds = torch.max(outputs, 1)
            n_correct += (preds == labels).sum().item()
            n_samples += batch_size
    if n_samples == 0:
        # An empty loader has no meaningful metrics; avoid dividing by zero.
        return float('nan'), float('nan')
    return loss_sum / n_samples, n_correct / n_samples


def _plot_history(train_losses, val_losses, val_accuracies):
    """Plot loss curves and validation accuracy against 1-based epoch numbers."""
    # 1-based so the x-axis matches the "Epoch k/N" numbers printed during training.
    epochs = range(1, len(train_losses) + 1)
    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 4))

    loss_ax.plot(epochs, train_losses, label='Training Loss')
    loss_ax.plot(epochs, val_losses, label='Validation Loss')
    loss_ax.set(xlabel='Epoch', ylabel='Loss', title='Training and Validation Loss')
    loss_ax.legend()

    acc_ax.plot(epochs, val_accuracies, label='Validation Accuracy')
    acc_ax.set(xlabel='Epoch', ylabel='Accuracy', title='Validation Accuracy')
    acc_ax.legend()

    plt.show()


def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=12,
                device=None, plot=True):
    """Fine-tune `model`, validate after every epoch, and optionally plot the curves.

    Args:
        model: classifier whose forward output exposes `.logits`.
        train_loader: iterable of (inputs, labels) training batches.
        val_loader: iterable of (inputs, labels) validation batches.
        criterion: loss function applied to (logits, labels).
        optimizer: optimizer over the model's parameters.
        num_epochs: number of passes over `train_loader`.
        device: device that batches are moved to. Defaults to the device
            holding the model's first parameter.
        plot: when True, show the loss/accuracy curves after training.

    Returns:
        dict of per-epoch lists under 'train_loss', 'val_loss' and 'val_acc'.
    """
    if device is None:
        device = next(model.parameters()).device

    history = {'train_loss': [], 'val_loss': [], 'val_acc': []}
    for epoch in range(num_epochs):
        epoch_loss = _train_one_epoch(model, train_loader, criterion, optimizer, device)
        history['train_loss'].append(epoch_loss)
        print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}')

        val_loss, val_acc = _evaluate(model, val_loader, criterion, device)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
        print(f'Validation Loss: {val_loss:.4f}, Accuracy: {val_acc:.4f}')

    if plot:
        _plot_history(history['train_loss'], history['val_loss'], history['val_acc'])
    return history

# Fine-tune for 12 epochs; the training curves are plotted at the end.
train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    num_epochs=12,
)

In [None]:
from PIL import Image
import torch
from torchvision import transforms

def test_model_with_image(model, image_path, device, transform=None, class_names=None):
    """Predict the class of a single image file, print it, and return it.

    Args:
        model: classifier whose forward output exposes `.logits`.
        image_path: path to the image file to classify.
        device: device the model lives on; the input batch is moved there.
        transform: preprocessing pipeline. Defaults to the same
            resize/normalise pipeline used for training.
        class_names: label-index -> name lookup. Defaults to the notebook's
            global `classes` list.

    Returns:
        The predicted class name.
    """
    if transform is None:
        # Same preprocessing as used during training.
        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
    if class_names is None:
        class_names = classes

    # Close the file handle as soon as the image has been decoded.
    with Image.open(image_path) as img:
        image = transform(img.convert('RGB')).unsqueeze(0)  # Add batch dimension
    image = image.to(device)

    model.eval()
    with torch.no_grad():
        outputs = model(image).logits
        _, preds = torch.max(outputs, 1)

    predicted_class = class_names[preds.item()]
    print(f'The predicted class is: {predicted_class}')
    return predicted_class

# Example usage: pick an image from the held-out validation split so the
# prediction is not made on an image the model was trained on.
# Replace `image_path` with your own file to classify a custom image.
example_idx = val_indices[0]
image_path = dataset.image_paths[example_idx]
print(f'True class: {classes[dataset.labels[example_idx]]}')
test_model_with_image(model, image_path, device)

In [None]:
# Show the full module tree (backbone plus the replaced classification head).
print(repr(model))


In [None]:
# Persist only the weights (state_dict). Rebuilding the model for inference needs
# the same architecture and head size, and the `classes` list to decode labels,
# since neither is stored in this file.
CHECKPOINT_PATH = "resnet50_snake_identification_model.pth"
torch.save(model.state_dict(), CHECKPOINT_PATH)

In [None]:
# Shell listing of the Kaggle working directory, to confirm the checkpoint file was written.
!ls "/kaggle/working"

In [None]:
from IPython.display import FileLink

# Render a clickable download link for the saved checkpoint (relative to the CWD).
checkpoint_link = FileLink('./resnet50_snake_identification_model.pth')
checkpoint_link