In [1]:
import sys
import os

# Put the parent of the current working directory on the import path so the
# project-local `datasets` / `models` packages imported below resolve.
# Assumes the notebook is launched from a sub-folder of the project root.
project_root = os.path.abspath(
    os.path.join(os.getcwd(), os.pardir)
)
_root_already_on_path = project_root in sys.path
if not _root_already_on_path:
    sys.path.insert(0, project_root)

from datasets.waterbirds import WaterbirdsDataset
from models.resnet_classifier import ResNet50Classifier


In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from datasets.waterbirds import WaterbirdsDataset
from models.resnet_classifier import ResNet50Classifier
from sklearn.metrics import accuracy_score

In [3]:
# ---- Configuration ---- #
# Dataset root. Set the WATERBIRDS_DIR environment variable to point at your
# local copy instead of editing this cell; the author's path is the fallback.
data_dir = os.environ.get(
    "WATERBIRDS_DIR",
    "F:/HLCV/HLCV_project/HLCV_project/datasets/data/waterbird",
)
batch_size = 32
epochs = 10
lr = 1e-4
seed = 42
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed torch's RNGs (all devices) so DataLoader shuffling and any randomly
# initialised layers repeat across Restart & Run All. This does not force
# deterministic cuDNN kernels, so GPU results may still vary slightly.
torch.manual_seed(seed)

In [4]:
# ---- Transforms ---- #
# One preprocessing pipeline, shared by the train and val splits.
# The Normalize statistics are the standard ImageNet channel mean/std,
# presumably chosen to match the pretrained backbone (TODO confirm).
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),  # fixed square size (aspect ratio is not preserved)
        transforms.ToTensor(),          # image -> float tensor scaled to [0, 1]
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [5]:
# ---- Datasets & Loaders ---- #
# Both splits share the same deterministic preprocessing `transform`.
train_ds = WaterbirdsDataset(data_dir, split='train', transform=transform)
val_ds = WaterbirdsDataset(data_dir, split='val', transform=transform)

# NOTE(review): on Windows each epoch re-spawns worker processes; lower this
# (or use 0) if loader start-up dominates epoch time.
num_workers = 4
# Page-locked host memory speeds up host->GPU copies; it is useless on CPU.
use_pinned_memory = device.type == "cuda"

train_loader = DataLoader(
    train_ds,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=use_pinned_memory,
)
val_loader = DataLoader(
    val_ds,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=use_pinned_memory,
)



In [6]:
# Loss, network and optimiser for training.
criterion = nn.CrossEntropyLoss()  # built first; constructing it has no side effects
model = ResNet50Classifier(pretrained=True)
model = model.to(device)  # rebinding keeps whatever `.to` returns, as before
optimizer = optim.Adam(params=model.parameters(), lr=lr)



In [7]:
# ---- Training Loop ---- #
def train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader``.

    Puts ``model`` in train mode, because ``evaluate`` leaves it in eval mode
    between epochs. Returns the mean per-sample training loss, or 0.0 if the
    loader yields nothing.
    """
    model.train()
    total_loss = 0.0
    n_seen = 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Assumes `criterion` returns the batch mean (the nn.CrossEntropyLoss
        # default), so re-weight by batch size to get a per-sample average.
        total_loss += loss.item() * images.size(0)
        n_seen += images.size(0)
    return total_loss / max(n_seen, 1)


@torch.no_grad()
def evaluate(model, loader, device):
    """Return the fraction of samples in ``loader`` that ``model`` classifies correctly.

    Switches ``model`` to eval mode and disables autograd. Assumes the labels
    are class indices, as consumed by nn.CrossEntropyLoss during training.
    Returns NaN for an empty loader.

    TODO: Waterbirds is usually reported with worst-group accuracy as well;
    that needs the group annotation, which these loaders do not appear to expose.
    """
    model.eval()
    n_correct = 0
    n_total = 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        preds = model(images).argmax(dim=1)
        n_correct += (preds == labels).sum().item()
        n_total += labels.size(0)
    return n_correct / n_total if n_total else float("nan")


for epoch in range(epochs):
    avg_loss = train_one_epoch(model, train_loader, criterion, optimizer, device)
    val_acc = evaluate(model, val_loader, device)
    print(f"Epoch {epoch+1}/{epochs} - Train Loss: {avg_loss:.4f} - Val Acc: {val_acc:.4f}")


KeyboardInterrupt: 