In [None]:
import os
import pandas as pd
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt

# Standard ImageNet per-channel statistics (the values torchvision's
# ImageNet-pretrained models were trained with).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Preprocessing applied to every image: PIL image -> float tensor, then
# per-channel normalization with the ImageNet statistics above.
data_transforms = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
)

# Custom dataset class
class CustomDataset(torch.utils.data.Dataset):
    """Image-classification dataset described by a CSV file.

    Each CSV row is read positionally: column 0 holds the image file name and
    column 1 holds the class id (the ``hotel_id``). The image for a row is
    loaded from ``<root_dir>/<column 1>/<column 0>`` and converted to RGB.

    Args:
        csv_file: Path to the CSV file; it must contain a ``hotel_id`` column.
        root_dir: Directory containing one sub-directory per class id.
        transform: Optional callable applied to the loaded RGB PIL image.
        classes: Optional list of class ids fixing the label order (a sample's
            label is the index of its class id in this list). Pass the training
            set's ``classes`` when building a validation/test set so both use
            the same id -> label mapping. When omitted, the order of first
            appearance of ``hotel_id`` in this CSV is used.
    """

    def __init__(self, csv_file, root_dir, transform=None, classes=None):
        self.data = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform
        if classes is None:
            classes = self.data['hotel_id'].unique().tolist()
        self.classes = classes  # goes through the setter below

    @property
    def classes(self):
        """List of class ids; a class's label is its index in this list."""
        return self._classes

    @classes.setter
    def classes(self, value):
        # Keep a dict alongside the list so label lookup is O(1) per sample
        # instead of an O(num_classes) list.index() scan.
        self._classes = list(value)
        self._class_to_idx = {}
        for label, class_id in enumerate(self._classes):
            # First occurrence wins, matching list.index() semantics.
            self._class_to_idx.setdefault(class_id, label)

    def __len__(self):
        return len(self.data)

    def _locate(self, idx):
        """Return ``(image_path, label)`` for row ``idx`` without opening the image.

        Raises:
            ValueError: if the row's class id is not in ``self.classes``.
        """
        file_name = self.data.iloc[idx, 0]
        class_id = self.data.iloc[idx, 1]
        try:
            label = self._class_to_idx[class_id]
        except KeyError:
            raise ValueError(f"{class_id!r} is not in classes") from None
        img_path = os.path.join(self.root_dir, str(class_id), str(file_name))
        return img_path, label

    def __getitem__(self, idx):
        img_path, label = self._locate(idx)
        # The context manager closes the underlying file handle even if
        # decoding fails; convert() returns an independent in-memory copy.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        if self.transform:
            image = self.transform(image)

        return image, label

# Select the device first so a saved checkpoint can be mapped onto it.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Training data; labels are indices into train_dataset.classes.
train_dataset = CustomDataset('train.csv', 'final/train_images', transform=data_transforms)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)

# Validation data is built once, outside the epoch loop. It must reuse the
# training set's class order: otherwise CustomDataset derives its own order
# from validation.csv and the same hotel_id maps to a different label index
# than the one the model was trained to predict.
val_dataset = CustomDataset('validation.csv', 'final/validation_images', transform=data_transforms)
val_dataset.classes = train_dataset.classes
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=32, shuffle=False)

# ResNet-50 with ImageNet weights, final layer replaced to output one logit per
# training class. IMAGENET1K_V1 is what the deprecated `pretrained=True`
# selected; the `weights=` API avoids the torchvision >= 0.13 deprecation warning.
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, len(train_dataset.classes))
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Resume from an earlier best checkpoint if one exists. map_location lets a
# checkpoint saved on a GPU load on a CPU-only machine.
checkpoint_path = 'resnet_checkpoint/best_model_checkpoint.pth'
if os.path.exists(checkpoint_path):
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    print("Loaded the best model checkpoint!")

model.to(device)

# New best checkpoints are written here; create the directory up front so
# torch.save does not fail on a missing folder.
save_checkpoint_path = 'resnet_checkpointv3/best_model_checkpoint.pth'
os.makedirs(os.path.dirname(save_checkpoint_path), exist_ok=True)


def _evaluate(model, loader, criterion, device):
    """Return ``(accuracy_percent, mean_batch_loss)`` of ``model`` over ``loader``.

    Puts the model in eval mode so BatchNorm uses (and does not update) its
    running statistics; the caller must switch back with ``model.train()``.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    loss_sum = 0.0
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            loss_sum += criterion(outputs, labels).item()
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)
    if total == 0:
        raise ValueError("validation set is empty")
    return 100 * correct / total, loss_sum / len(loader)


num_epochs = 10
log_every = 50  # print the batch loss every this many iterations (and on the last one)
train_losses = []
val_losses = []
best_accuracy = 0.0  # best validation accuracy seen in this run

for epoch in range(num_epochs):
    # _evaluate() leaves the model in eval mode; restore training behaviour.
    model.train()
    running_loss = 0.0
    num_batches = len(train_loader)
    for i, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if (i + 1) % log_every == 0 or i + 1 == num_batches:
            print(f"Iteration {i+1}/{num_batches}, Loss: {loss.item():.4f}")

    epoch_loss = running_loss / num_batches
    train_losses.append(epoch_loss)
    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {epoch_loss:.4f}")

    accuracy, val_loss = _evaluate(model, val_loader, criterion, device)
    val_losses.append(val_loss)
    print(f"Validation Accuracy: {accuracy:.2f}%")

    # Save the model checkpoint if validation accuracy improves
    if accuracy > best_accuracy:
        best_accuracy = accuracy
        torch.save(model.state_dict(), save_checkpoint_path)
        print("Model checkpoint saved!")

# Plotting the train and validation loss curves
plt.figure(figsize=(10, 5))
plt.plot(range(1, num_epochs + 1), train_losses, label='Train Loss')
plt.plot(range(1, num_epochs + 1), val_losses, label='Validation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.title('Train and Validation Loss Curves')
plt.legend()
plt.show()

  f"The parameter '{pretrained_param}' is deprecated since 0.13 and may be removed in the future, "


Loaded the best model checkpoint!
Iteration 1/1072, Loss: 4.1020
Iteration 2/1072, Loss: 3.8162
Iteration 3/1072, Loss: 3.5488
Iteration 4/1072, Loss: 3.8963
Iteration 5/1072, Loss: 3.4605
Iteration 6/1072, Loss: 3.3014
Iteration 7/1072, Loss: 3.7696
Iteration 8/1072, Loss: 3.7900
Iteration 9/1072, Loss: 3.6338
Iteration 10/1072, Loss: 3.2801
Iteration 11/1072, Loss: 4.3602
Iteration 12/1072, Loss: 3.3779
Iteration 13/1072, Loss: 3.6566
Iteration 14/1072, Loss: 3.8195
Iteration 15/1072, Loss: 3.2265
Iteration 16/1072, Loss: 3.7001
Iteration 17/1072, Loss: 3.5733
Iteration 18/1072, Loss: 3.8221
Iteration 19/1072, Loss: 3.9015
Iteration 20/1072, Loss: 3.4108
Iteration 21/1072, Loss: 4.0448
Iteration 22/1072, Loss: 3.5658
Iteration 23/1072, Loss: 3.7551
Iteration 24/1072, Loss: 4.0526
Iteration 25/1072, Loss: 3.5170
Iteration 26/1072, Loss: 3.7856
Iteration 27/1072, Loss: 3.6658
Iteration 28/1072, Loss: 4.0142
Iteration 29/1072, Loss: 4.3232
Iteration 30/1072, Loss: 3.7686
Iteration 31/10