In [1]:
#from data import CocoDatasetManager
#from collections import Counter
from backbone import ClassificationRCNN
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from PIL import UnidentifiedImageError
import collections


In [2]:
# --- Preprocessing ----------------------------------------------------------
# Each image is resized to a fixed 224x224 input, converted to a tensor and
# normalised per channel with the ImageNet mean / standard deviation.
IMAGE_SIZE = (224, 224)
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

preprocessing_steps = [
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
]
transform = transforms.Compose(preprocessing_steps)

# --- Dataset ----------------------------------------------------------------
# NOTE(review): absolute, machine-specific path; consider a configurable
# data directory so the notebook runs elsewhere.
train_dataset = datasets.ImageFolder(
    root='/home/amanandhar/Transferlearning/RCNN-classification/src/data/train',
    transform=transform,
)

# --- Loader -----------------------------------------------------------------
# Shuffled mini-batches of 32 images, fetched by 4 worker processes.
train_loader = DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4,
)

In [3]:
# Inspect the label set exposed by the dataset that backs the training loader.
dataset = train_loader.dataset
classes, num_classes = dataset.classes, len(dataset.classes)

# Two-line summary: the class names, then how many there are.
print(
    f"Classes: {classes}",
    f"Number of classes: {num_classes}",
    sep="\n",
)


Classes: ['airplane', 'apple', 'backpack', 'banana', 'baseball bat', 'baseball glove', 'bear', 'bed', 'bench', 'bicycle', 'bird', 'boat', 'book', 'bottle', 'bowl', 'broccoli', 'bus', 'cake', 'car', 'carrot', 'cat', 'cell phone', 'chair', 'clock', 'couch', 'cow', 'cup', 'dining table', 'dog', 'donut', 'elephant', 'fire hydrant', 'fork', 'frisbee', 'giraffe', 'hair drier', 'handbag', 'horse', 'hot dog', 'keyboard', 'kite', 'knife', 'laptop', 'microwave', 'motorcycle', 'mouse', 'orange', 'oven', 'parking meter', 'person', 'pizza', 'potted plant', 'refrigerator', 'remote', 'sandwich', 'scissors', 'sheep', 'sink', 'skateboard', 'skis', 'snowboard', 'spoon', 'sports ball', 'stop sign', 'suitcase', 'surfboard', 'teddy bear', 'tennis racket', 'tie', 'toaster', 'toilet', 'toothbrush', 'traffic light', 'truck', 'tv', 'umbrella', 'vase', 'wine glass', 'zebra']
Number of classes: 79


In [4]:
# Build the classifier together with its loss function and optimiser.
#
# The output size comes from `num_classes`, which is derived from the
# dataset's class list, rather than a hard-coded literal: the model head then
# always matches the labels the loader actually produces, even if class
# folders are added or removed.
model = ClassificationRCNN(num_classes=num_classes)

# Cross-entropy loss between the model outputs and the integer class labels.
criterion = torch.nn.CrossEntropyLoss()

# Adam over every model parameter with a learning rate of 1e-3.
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)



In [5]:
num_epochs = 50

# Only train when there is something to iterate over. Previously the
# empty-loader message was printed and the loop ran anyway, ending in a
# NameError on `loss` at the end of the first epoch.
if len(train_loader) == 0:
    print("The DataLoader is empty. Check your dataset.")
else:
    print(f"Starting training for {num_epochs} epochs.")

    for epoch in range(num_epochs):
        print(f"Epoch {epoch+1} started.")
        model.train()  # make sure the model is in training mode for this epoch

        # Per-epoch statistics, so the end-of-epoch report never depends on a
        # `loss` variable that may be undefined or left over from an earlier
        # batch or epoch.
        epoch_loss_sum = 0.0
        successful_batches = 0

        # NOTE: the try/except below only guards the forward/backward/step
        # code. Errors raised while the loader fetches a batch happen in the
        # `for` statement itself and are not caught here.
        for i, (images, labels) in enumerate(train_loader):
            try:
                # Forward pass
                outputs = model(images)
                loss = criterion(outputs, labels)

                # Backward pass and optimization
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            except Exception as e:
                # Best effort: report the failure and move on to the next batch.
                print(f"An error occurred: {e}")
                continue

            batch_loss = loss.item()
            epoch_loss_sum += batch_loss
            successful_batches += 1

            if i % 10 == 0:  # Print every 10 batches
                print(f"Epoch [{epoch+1}/{num_epochs}], Batch {i}, Loss: {batch_loss:.4f}")

        if successful_batches > 0:
            # Report the mean loss over the batches that actually trained.
            avg_epoch_loss = epoch_loss_sum / successful_batches
            print(f'Epoch [{epoch+1}/{num_epochs}] completed, Loss: {avg_epoch_loss:.4f}')
        else:
            print(f'Epoch [{epoch+1}/{num_epochs}] completed, no batch was processed successfully.')


KeyboardInterrupt: 

In [7]:
num_epochs = 50

for epoch in range(num_epochs):
    print("epoch: ", epoch)
    model.train()  # Set the model to training mode

    # Running statistics for this epoch. The batch counter must start at 0:
    # starting it at 32 (the batch size) divided the summed loss by
    # (batches + 32), under-reporting the average, and made the
    # "no valid batches" branch below unreachable.
    total_loss = 0.0
    num_batches = 0

    for images, labels in train_loader:
        try:
            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Only batches that completed the full step count toward the average.
            total_loss += loss.item()
            num_batches += 1

        except Exception as e:
            print(f"Skipping a batch due to an error: {e}")
            continue  # Skip this batch and continue with the next

    if num_batches > 0:
        # Mean loss over the batches that trained successfully in this epoch.
        avg_train_loss = total_loss / num_batches
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_train_loss:.4f}')
    else:
        print(f'Epoch [{epoch+1}/{num_epochs}], No valid batches processed.')

# Add validation phase and model saving as needed



epoch:  0


KeyboardInterrupt: 