<a href="https://colab.research.google.com/github/2203a51759/AIML/blob/main/GroupprojectGroup_3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Train.py

In [None]:
import torch
import argparse
import torch.nn as nn
import torch.optim as optim
import time

from tqdm.auto import tqdm

from model import build_model
from datasets import get_datasets, get_data_loaders
from utils import save_model, save_plots

import os

parser = argparse.ArgumentParser()
parser.add_argument(
    '-e', '--epochs', type=int, default=20,
    help='Number of epochs to train our network for'
)
parser.add_argument(
    '-lr', '--learning-rate', type=float,
    dest='learning_rate', default=0.001,
    help='Learning rate for training the model'
)
args = vars(parser.parse_args())

# All artifacts (weights and plots) are written here; create it up front so
# saving at the end of a long run cannot fail on a missing directory.
OUTPUT_DIR = '../outputs'
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Use the GPU when one is available; the model and each batch must share a device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Get the datasets and data loaders
train_dataset, valid_dataset = get_datasets()
train_loader, valid_loader = get_data_loaders(train_dataset, valid_dataset)

# Size the output layer from the data instead of hard-coding it.
# get_datasets() returns Subsets of an ImageFolder, so the class list lives on
# `.dataset`; fall back to the previous value of 10 if it is not available.
base_dataset = getattr(train_dataset, 'dataset', train_dataset)
class_names = getattr(base_dataset, 'classes', None)
num_classes = len(class_names) if class_names else 10

# Build the model
model = build_model(pretrained=True, fine_tune=False, num_classes=num_classes)
model = model.to(device)

# Define the loss function and optimizer. Only trainable parameters are handed
# to the optimizer (a frozen backbone has requires_grad=False).
criterion = nn.CrossEntropyLoss()
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.Adam(trainable_params, lr=args['learning_rate'])

# Per-epoch history consumed by save_plots() at the end. These were previously
# never defined, so the final save_plots() call raised NameError.
train_losses, valid_losses = [], []
train_accuracies, valid_accuracies = [], []

for epoch in range(args['epochs']):
    # ---- training pass ----
    model.train()
    total_loss, total_correct, total_seen = 0.0, 0, 0
    for images, labels in tqdm(train_loader):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        total_correct += (outputs.argmax(dim=1) == labels).sum().item()
        total_seen += labels.size(0)
    # Mean of per-batch losses, matching the original reporting.
    train_loss = total_loss / len(train_loader)
    train_accuracy = total_correct / total_seen
    train_losses.append(train_loss)
    train_accuracies.append(train_accuracy)
    print(f'Epoch {epoch+1}, Loss: {train_loss}, Training Accuracy: {train_accuracy:.4f}')

    # ---- validation pass (no gradient tracking) ----
    model.eval()
    total_loss, total_correct, total_seen = 0.0, 0, 0
    with torch.no_grad():
        for images, labels in valid_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            total_loss += criterion(outputs, labels).item()
            total_correct += (outputs.argmax(dim=1) == labels).sum().item()
            total_seen += labels.size(0)
    # Guard against an empty validation split (e.g. fewer than 10 images).
    valid_loss = total_loss / max(len(valid_loader), 1)
    valid_accuracy = total_correct / max(total_seen, 1)
    valid_losses.append(valid_loss)
    valid_accuracies.append(valid_accuracy)
    print(f'Epoch {epoch+1}, Validation Loss: {valid_loss:.4f}, '
          f'Validation Accuracy: {valid_accuracy:.4f}')

# Save the model
save_model(model, os.path.join(OUTPUT_DIR, 'model.pth'))

# Save the plots
save_plots(train_losses, valid_losses, train_accuracies, valid_accuracies)

model.py


In [None]:
import torchvision.models as models
import torch.nn as nn

def build_model(pretrained=True, fine_tune=False, num_classes=10):
    """Build a MobileNetV3-Large classifier with a new output layer.

    Args:
        pretrained: if True, load torchvision's pretrained weights.
        fine_tune: if True, every pre-existing parameter stays trainable; if
            False, all pre-existing parameters are frozen and only the newly
            created output layer is trained.
        num_classes: number of logits produced by the new final layer.

    Returns:
        The assembled ``torch.nn.Module``.
    """
    if pretrained:
        print('[INFO]: Loading pre-trained weights')
    else:
        print('[INFO]: Not loading pre-trained weights')

    # torchvision >= 0.13 replaced `pretrained=` with `weights=` (the old
    # keyword is deprecated); use the legacy keyword only on older releases.
    weights_enum = getattr(models, 'MobileNet_V3_Large_Weights', None)
    if weights_enum is not None:
        weights = weights_enum.DEFAULT if pretrained else None
        model = models.mobilenet_v3_large(weights=weights)
    else:
        model = models.mobilenet_v3_large(pretrained=pretrained)

    # Parameters are trainable by default, so the previous
    # `if fine_tune: requires_grad = True` never froze anything. Set the flag
    # explicitly in both directions.
    for param in model.parameters():
        param.requires_grad = fine_tune

    # Replace only the last layer (as the original comment intended) so the
    # pretrained hidden layers of the classifier head are kept. The new layer
    # is created after freezing, so it is always trainable.
    last_layer = model.classifier[-1]
    model.classifier[-1] = nn.Linear(last_layer.in_features, num_classes)

    return model

datasets.py


In [None]:
import torch
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms

ROOT_DIR = '../input/rice_leaf_diseases'
VALID_SPLIT = 0.1   # fraction of images held out for validation
IMAGE_SIZE = 224    # side length of the square center crop fed to the model
SEED = 42           # makes the train/validation split reproducible


def get_datasets(root_dir=ROOT_DIR, valid_split=VALID_SPLIT, seed=SEED):
    """Load the ImageFolder at ``root_dir`` and split it into train/validation.

    Args:
        root_dir: directory laid out as ``root_dir/<class_name>/<image>``.
        valid_split: fraction of samples (rounded down) placed in validation.
        seed: seed for the shuffle that decides the split.

    Returns:
        ``(train_dataset, valid_dataset)`` as ``torch.utils.data.Subset``
        objects over the same underlying ImageFolder.
    """
    data_transforms = transforms.Compose([
        transforms.Resize(IMAGE_SIZE),
        transforms.CenterCrop(IMAGE_SIZE),
        transforms.ToTensor(),
        # ImageNet channel statistics, matching the pretrained backbone.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    dataset = datasets.ImageFolder(root_dir, transform=data_transforms)

    # ImageFolder stores samples grouped by class, so slicing the unshuffled
    # index list put (almost) only the first class into validation. Shuffle
    # with a seeded generator so the split is random yet reproducible.
    generator = torch.Generator().manual_seed(seed)
    indices = torch.randperm(len(dataset), generator=generator).tolist()
    split = int(valid_split * len(dataset))
    train_indices, valid_indices = indices[split:], indices[:split]

    train_dataset = Subset(dataset, train_indices)
    valid_dataset = Subset(dataset, valid_indices)

    return train_dataset, valid_dataset

def get_data_loaders(train_dataset, valid_dataset, batch_size=32, num_workers=0):
    """Wrap the two datasets in DataLoaders.

    The training loader reshuffles every epoch. The validation loader keeps a
    fixed order, so validation metrics are comparable across epochs.

    Args:
        train_dataset: dataset used for training.
        valid_dataset: dataset used for validation.
        batch_size: samples per batch for both loaders (previously fixed at 32).
        num_workers: worker processes for loading (0 = load in the main process).

    Returns:
        ``(train_loader, valid_loader)``.
    """
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers
    )
    valid_loader = DataLoader(
        valid_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
    )

    return train_loader, valid_loader


utils.py

In [None]:
import matplotlib.pyplot as plt

def save_model(model, path):
    """Save ``model.state_dict()`` to ``path``, creating parent directories.

    Only parameters and buffers are stored. To reload, rebuild the same
    architecture and call ``load_state_dict``.
    """
    # This module cell never imported torch, so the previous body raised
    # NameError when utils.py was imported on its own; import locally.
    import os
    import torch

    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    torch.save(model.state_dict(), path)

def _save_curve_plot(train_values, valid_values, metric, path):
    """Plot one training/validation curve pair on a fresh figure and save it."""
    fig, ax = plt.subplots()
    ax.plot(train_values, label=f'Training {metric}')
    ax.plot(valid_values, label=f'Validation {metric}')
    ax.set_xlabel('Epochs')
    ax.set_ylabel(metric)
    ax.legend()
    fig.savefig(path)
    # Close explicitly so repeated calls do not accumulate open figures.
    plt.close(fig)


def save_plots(train_losses, valid_losses, train_accuracies, valid_accuracies,
               out_dir='../outputs'):
    """Write ``loss.png`` and ``accuracy.png`` into ``out_dir``.

    Each chart gets its own figure. Previously both were drawn on the same
    implicit figure, so accuracy.png also contained the loss curves.

    Args:
        train_losses, valid_losses: per-epoch loss values.
        train_accuracies, valid_accuracies: per-epoch accuracy values.
        out_dir: destination directory, created if missing.
    """
    import os

    os.makedirs(out_dir, exist_ok=True)
    _save_curve_plot(train_losses, valid_losses, 'Loss',
                     os.path.join(out_dir, 'loss.png'))
    _save_curve_plot(train_accuracies, valid_accuracies, 'Accuracy',
                     os.path.join(out_dir, 'accuracy.png'))


inference.py

In [None]:
import torch
from PIL import Image
from torchvision import transforms

def