In [1]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Preprocessing shared by the training and validation sets: resize so the
# shorter side is 256 px, keep the central 224x224 crop, convert to a tensor,
# then standardise each channel with the usual ImageNet mean / std.
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)

# One sub-folder per class under each root; ImageFolder derives the integer
# labels from the folder names.
# NOTE(review): absolute, machine-specific paths - consider a configurable data dir.
train_dataset = datasets.ImageFolder(
    r"D:\Download\DatasetResnet\train",
    transform=transform,
)
val_dataset = datasets.ImageFolder(
    r"D:\Download\DatasetResnet\val",
    transform=transform,
)

# Mini-batches of 32; only the training set is reshuffled every epoch.
train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=32, shuffle=False)


In [2]:
from torchvision import models
import torch.nn as nn

# Load a ResNet-50 with ImageNet pre-trained weights.
# The `pretrained=` flag is deprecated since torchvision 0.13 in favour of the
# `weights=` enum; IMAGENET1K_V1 is the checkpoint that `pretrained=True` loaded.
# Older torchvision builds without the enum fall back to the legacy flag.
if hasattr(models, "ResNet50_Weights"):
    model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
else:
    model = models.resnet50(pretrained=True)

# Freeze every pre-trained parameter so the backbone receives no gradient updates.
for param in model.parameters():
    param.requires_grad = False

# Replace the final fully connected layer with a new 4-output head.
# The new layer is created after the freeze, so its parameters stay trainable.
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, 4)

# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)




In [3]:
from torch import optim

# Multi-class classification loss applied to the model's raw outputs.
criterion = nn.CrossEntropyLoss()
# Adam is given only the parameters of the `fc` head, so nothing else is updated.
optimizer = optim.Adam(params=model.fc.parameters(), lr=1e-3)


In [4]:
def train_model(model, criterion, optimizer, train_loader, val_loader, epochs=25):
    """Train ``model`` and print the average training / validation loss per epoch.

    Args:
        model: Network to optimise; called on each input batch.
        criterion: Loss callable taking ``(outputs, labels)``.
        optimizer: Optimizer stepped once per training batch.
        train_loader: Iterable of ``(inputs, labels)`` batches exposing ``.dataset``.
        val_loader: Iterable of ``(inputs, labels)`` batches exposing ``.dataset``.
        epochs: Number of passes over ``train_loader``.

    Returns:
        None. Progress is reported through ``print`` only.
    """
    # Take the device from the model itself instead of a notebook-level `device`
    # variable, so the function does not depend on which cells ran beforehand.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")

    for epoch in range(epochs):
        # ---- training pass ----
        model.train()
        running_loss = 0.0

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(run_device), labels.to(run_device)

            optimizer.zero_grad()

            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Weight the batch loss by the batch size so a smaller final batch
            # does not skew the per-sample average (assumes the criterion
            # returns a batch mean, as nn.CrossEntropyLoss does by default).
            running_loss += loss.item() * inputs.size(0)

        epoch_loss = running_loss / len(train_loader.dataset)
        print(f'Epoch {epoch+1}/{epochs}, Train Loss: {epoch_loss:.4f}')

        # ---- validation pass: eval mode, no gradient tracking ----
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(run_device), labels.to(run_device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                val_loss += loss.item() * inputs.size(0)

        val_loss /= len(val_loader.dataset)
        print(f'Epoch {epoch+1}/{epochs}, Val Loss: {val_loss:.4f}')

# Fine-tune for 10 epochs (overrides the function's default of 25).
train_model(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    train_loader=train_loader,
    val_loader=val_loader,
    epochs=10,
)


Epoch 1/10, Train Loss: 0.5873
Epoch 1/10, Val Loss: 0.2531
Epoch 2/10, Train Loss: 0.1859
Epoch 2/10, Val Loss: 0.1711
Epoch 3/10, Train Loss: 0.1342
Epoch 3/10, Val Loss: 0.1568
Epoch 4/10, Train Loss: 0.1247
Epoch 4/10, Val Loss: 0.1612
Epoch 5/10, Train Loss: 0.0916
Epoch 5/10, Val Loss: 0.1116
Epoch 6/10, Train Loss: 0.0708
Epoch 6/10, Val Loss: 0.1041
Epoch 7/10, Train Loss: 0.0736
Epoch 7/10, Val Loss: 0.1140
Epoch 8/10, Train Loss: 0.0616
Epoch 8/10, Val Loss: 0.0980
Epoch 9/10, Train Loss: 0.0569
Epoch 9/10, Val Loss: 0.0972
Epoch 10/10, Train Loss: 0.0522
Epoch 10/10, Val Loss: 0.0989


In [5]:
# Persist only the model's state dict to the current working directory.
torch.save(obj=model.state_dict(), f='trained_resnet50.pth')


In [7]:
from PIL import Image

def predict(image_path, model, device, transform):
    """Classify one image file and return the index of the top-scoring class.

    Args:
        image_path: Path of the image file to open with PIL.
        model: Network called on a batch holding the single image; put in eval mode.
        device: Device the image tensor is moved to before the forward pass.
        transform: Callable turning a PIL image into a tensor without a batch dimension.

    Returns:
        int: Position of the largest value along dimension 1 of the model output.
    """
    model.eval()
    # Convert to RGB so greyscale / RGBA / CMYK files produce the three channels
    # the pipeline expects; ImageFolder's default loader does the same during
    # training. The context manager closes the file once pixels are copied.
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    # The batch axis goes in front (dim 0). The previous `unsqueeze(32)` is out
    # of range for a 3-D tensor and raised IndexError.
    image = transform(image).unsqueeze(0)
    image = image.to(device)

    with torch.no_grad():
        output = model(image)
        _, predicted = torch.max(output, 1)

    return predicted.item()


In [13]:
def predict(image_path, model, device, transform):
    """Return the predicted class index for a single image file.

    Args:
        image_path: Path of the image file to open with PIL.
        model: Network called on a one-image batch; switched to eval mode first.
        device: Device the input tensor is moved to before inference.
        transform: Callable mapping a PIL image to a tensor without a batch dimension.

    Returns:
        int: Index of the maximum along dimension 1 of the model output.
    """
    model.eval()
    # Force three channels: ImageFolder's default loader converts every training
    # image to RGB, so inference should see the same input format. Without this,
    # greyscale / RGBA / CMYK files reach the transform with a different channel
    # count. The `with` block releases the file handle promptly.
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    image = transform(image).unsqueeze(0)  # Add batch dimension
    image = image.to(device)

    with torch.no_grad():
        output = model(image)
        _, predicted = torch.max(output, 1)

    return predicted.item()

# Example usage
# Relies on `transform`, `model`, `device` and `train_dataset` from the cells above.
# The model was trained on labels assigned by ImageFolder, which indexes the
# sorted class-folder names, so output index i means train_dataset.classes[i].
# Reading the names from the dataset avoids a hand-typed list whose order or
# spelling could silently differ from the folders on disk.
class_names = train_dataset.classes
image_path = r"D:\Download\Equus_quagga_burchellii_-_Etosha,_2014.jpg"
predicted_class_index = predict(image_path, model, device, transform)
print(f'The predicted class is {class_names[predicted_class_index]}')



The predicted class is rhino
