In [9]:
import pandas as pd
import torch
from torch.utils.data import DataLoader
import torch.optim as optim
import torch.nn as nn
from tqdm import tqdm
from featout_exp.utils import CustomDataset, load_model

In [10]:
# Reading the dataframe
df = pd.read_csv('csvs/first_half.csv', index_col= 0)  # Replace 'your_data_file.csv' with your dataframe's path
df

Unnamed: 0,Path,Label
7597,PetImages/Cat/5586.jpg,Cat
1118,PetImages/Dog/11002.jpg,Dog
12138,PetImages/Dog/9673.jpg,Dog
2828,PetImages/Cat/1293.jpg,Cat
4089,PetImages/Cat/2428.jpg,Cat
...,...,...
9602,PetImages/Cat/7390.jpg,Cat
4008,PetImages/Cat/2355.jpg,Cat
2888,PetImages/Dog/1347.jpg,Dog
5773,PetImages/Dog/3944.jpg,Dog


In [11]:
dataset = CustomDataset(dataframe=df)

# Create DataLoader to handle batching
batch_size = 32  # You can modify this as needed
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [12]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = load_model(device)

In [13]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Training loop
num_epochs = 3
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    corrects = 0
    total = 0

    for inputs, labels in tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        inputs, labels = inputs.to(device), labels.to(device)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)
        loss = criterion(outputs, labels)

        # Backward pass and optimize
        loss.backward()
        optimizer.step()

        # Statistics
        running_loss += loss.item() * inputs.size(0)
        corrects += torch.sum(preds == labels.data)
        total += labels.size(0)

    epoch_loss = running_loss / total
    epoch_acc = corrects.double() / total
    print(f"Epoch {epoch+1}/{num_epochs} - Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}")

print("Training complete!")


Epoch 1/3:   4%|▍         | 15/391 [00:13<05:30,  1.14it/s]


KeyboardInterrupt: 