In [1]:
from src.AppleModel import *
from torch.utils.data import DataLoader, random_split
import os
import pandas as pd
import cv2
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets
from PIL import Image
import os
import matplotlib.pyplot as plt
import tqdm

def collate_fn(batch):
    """Transpose a batch of per-sample tuples into per-field tuples.

    Given ``[(a0, b0), (a1, b1), ...]`` this returns ``((a0, a1, ...), (b0, b1, ...))``
    without stacking or converting the elements. An empty batch yields ``()``.
    """
    columns = zip(*batch)
    return tuple(columns)


# Dataset location, relative to the notebook's working directory.
image_dir = "acfr-multifruit-2016/acfr-fruit-dataset/apples/images"
annotation_dir = "acfr-multifruit-2016/acfr-fruit-dataset/apples/annotations"

# Images are only converted to tensors; the resize step is left disabled.
transform = transforms.Compose([
    # transforms.Resize((202,308)),
    transforms.ToTensor(),
])

dataset = AppleDataset(image_dir, annotation_dir, transform=transform)

# 80/20 train/validation split; the validation set absorbs the rounding remainder.
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size

# Seed the split so the same samples land in train/val after every kernel restart.
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = random_split(
    dataset, [train_size, val_size], generator=split_generator
)

# Create dataloaders.
# Both loaders use the same collate_fn so validation batches have the same
# (images, annotations) tuple structure that the training loop consumes;
# the default collate would instead try to stack/merge the per-sample targets.
train_loader = DataLoader(
    train_dataset, batch_size=1, shuffle=True, num_workers=4, collate_fn=collate_fn
)
val_loader = DataLoader(
    val_dataset, batch_size=1, shuffle=False, num_workers=4, collate_fn=collate_fn
)

In [None]:
# Training hyper-parameters.
num_epochs = 10

# Pick the compute device: CUDA when available, otherwise CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Build the detector and place it on the chosen device.
model = AppleDetector(num_classes=2)
model.to(device)

# Put the model into training mode.
model.train()

# AdamW optimizer over all model parameters.
optimizer = optim.AdamW(model.parameters(), lr=1e-5)

# Step scheduler: multiply the learning rate by 0.1 every 3 scheduler steps.
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

In [3]:
# Fail early with a clear message instead of hitting a ZeroDivisionError when
# the per-epoch average is computed over an empty loader.
num_batches = len(train_loader)
if num_batches == 0:
    raise RuntimeError("train_loader is empty; check the dataset paths and split sizes.")

for epoch in range(num_epochs):
    model.train()  # Set the model to training mode
    total_loss = 0.0

    # Progress bar so a long epoch gives feedback; removed once the epoch ends
    # so only the summary line below stays in the cell output.
    batches = tqdm.tqdm(train_loader, desc=f"Epoch {epoch + 1}/{num_epochs}", leave=False)
    for images, annotations in batches:
        # Each batch is a tuple of per-sample images and a tuple of per-sample
        # target dicts; move every tensor to the training device.
        images = [img.to(device) for img in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in annotations]

        # Zero the gradients
        optimizer.zero_grad()

        # Forward pass: the result is used as a dict of loss terms.
        loss_dict = model(images, targets)

        # Total loss is the sum of all the losses in the dictionary
        losses = sum(loss_dict.values())

        # Backward pass
        losses.backward()

        # Update weights
        optimizer.step()

        # Accumulate the scalar loss of this batch
        total_loss += losses.item()

    # Average loss per batch for the epoch
    avg_loss = total_loss / num_batches
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}")

    # Step the learning rate scheduler once per epoch (optional: set to None to disable)
    if lr_scheduler is not None:
        lr_scheduler.step()


KeyboardInterrupt: 