In [None]:
#!pip install "numpy<2"

import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image

In [7]:
import os
import pandas as pd
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision import models
from PIL import Image
from torch.utils.data import Dataset, DataLoader

# Model Definition
class KeyPointModel(nn.Module):
    """ResNet-50 backbone whose final layer regresses key-point coordinates.

    torchvision's ResNet-50 classifier head (``fc``) is replaced by a linear
    layer with ``2 * num_keypoints`` outputs (two coordinates per key point;
    the ordering of the outputs follows whatever order the training targets use).

    Args:
        num_keypoints: Number of key points to predict. The default of 12
            (24 outputs) matches the original hard-coded head.
        pretrained: If True, start from the ImageNet checkpoint that the legacy
            ``resnet50(pretrained=True)`` call loaded; otherwise random init.
    """

    def __init__(self, num_keypoints=12, pretrained=True):
        super().__init__()
        # torchvision >= 0.13 deprecates ``pretrained=`` in favour of ``weights=``.
        # IMAGENET1K_V1 is the checkpoint the legacy flag mapped to, so the
        # initial weights are unchanged. Older torchvision lacks the enum.
        weights_enum = getattr(models, 'ResNet50_Weights', None)
        if weights_enum is None:
            self.resnet = models.resnet50(pretrained=pretrained)
        else:
            weights = weights_enum.IMAGENET1K_V1 if pretrained else None
            self.resnet = models.resnet50(weights=weights)
        # Swap the 1000-class ImageNet head for a coordinate-regression head.
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, 2 * num_keypoints)

    def forward(self, x):
        """Return the raw coordinate predictions, shape (batch, 2 * num_keypoints)."""
        return self.resnet(x)

# Build the key-point regressor and switch it to training mode in one step
# (nn.Module.train() returns the module itself, so the call chains).
model = KeyPointModel().train()

# Dataset Definition
class KeyPointDataset(Dataset):
    """Map-style dataset pairing image files with key-point targets.

    Args:
        image_paths: Sequence of image file paths.
        key_points: Sequence of targets aligned index-for-index with
            ``image_paths``; each item is returned unchanged.
        transform: Optional callable applied to the loaded PIL image.

    Raises:
        ValueError: If ``image_paths`` and ``key_points`` differ in length.
    """

    def __init__(self, image_paths, key_points, transform=None):
        # Catch misaligned inputs up front instead of an IndexError mid-epoch.
        if len(image_paths) != len(key_points):
            raise ValueError(
                f'image_paths ({len(image_paths)}) and key_points '
                f'({len(key_points)}) must have the same length'
            )
        self.image_paths = image_paths
        self.key_points = key_points
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, key_point)`` for sample ``idx``.

        The image is always converted to 3-channel RGB: the ImageNet-pretrained
        ResNet stem and a 3-value ``Normalize`` require three channels, while
        PNG files may be stored as grayscale or RGBA.
        """
        # The context manager closes the underlying file; convert() returns an
        # independent, fully loaded copy that outlives the handle.
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert('RGB')
        key_point = self.key_points[idx]

        if self.transform:
            image = self.transform(image)

        return image, key_point

# Function to load dataset from a specified folder
def load_dataset(folder_path, csv_name='augmented_labels.csv'):
    """Load image paths and key-point targets listed in a folder's label CSV.

    The CSV is read from ``os.path.join(folder_path, csv_name)``. It must have
    an ``image_name`` column, and every column after the FIRST is treated as a
    numeric key-point coordinate (so ``image_name`` is expected to be column 0).

    Args:
        folder_path: Directory containing both the CSV and the images it names.
        csv_name: File name of the label CSV inside ``folder_path``.

    Returns:
        ``(image_paths, key_points)``: two equal-length lists holding each
        image path joined onto ``folder_path``, and one float64
        ``torch.Tensor`` of coordinates per CSV row.
    """
    keypoint_data = pd.read_csv(os.path.join(folder_path, csv_name))

    # image_name is used verbatim (no extension is appended), so it must
    # already include the file suffix.
    image_paths = [
        os.path.join(folder_path, str(image_name))
        for image_name in keypoint_data['image_name']
    ]

    # Positional column selection via .iloc: plain Series[1:] positional
    # indexing is deprecated in pandas 2.x. One vectorized conversion replaces
    # the slow per-row iterrows() loop.
    coords = keypoint_data.iloc[:, 1:].to_numpy(dtype=float)
    key_points = [torch.tensor(row) for row in coords]

    return image_paths, key_points

# Paths to your train and test folders. Set FACEMAP_TRAIN_DIR / FACEMAP_TEST_DIR
# to point elsewhere without editing this cell.
train_folder = os.environ.get(
    'FACEMAP_TRAIN_DIR',
    '/Users/annastuckert/Documents/GitHub/ViT_facemap/ViT-pytorch/projects/Facemap/data/train/augmented_data',
)
test_folder = os.environ.get(
    'FACEMAP_TEST_DIR',
    '/Users/annastuckert/Documents/GitHub/ViT_facemap/ViT-pytorch/projects/Facemap/data/test/augmented_data',
)

# Load datasets
train_image_paths, train_key_points = load_dataset(train_folder)
test_image_paths, test_key_points = load_dataset(test_folder)

# Transformations (ImageNet mean/std, matching the pretrained backbone).
# NOTE(review): Resize(256) + CenterCrop(224) changes the image geometry, but the
# key-point targets are not rescaled or cropped to match, so the network has to
# regress coordinates in the original image frame, and points falling outside
# the crop are invisible to it. Losses in the thousands (see cell output) suggest
# the targets are un-normalized pixel coordinates -- consider resizing straight to
# 224x224 and scaling/normalizing the targets accordingly.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Create DataLoaders
train_dataset = KeyPointDataset(train_image_paths, train_key_points, transform)
train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True)

test_dataset = KeyPointDataset(test_image_paths, test_key_points, transform)
test_dataloader = DataLoader(test_dataset, batch_size=4, shuffle=False)

# Use the best available accelerator (CUDA, then Apple MPS), else CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif getattr(torch.backends, 'mps', None) is not None and torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
# Move the model before constructing the optimizer, as PyTorch recommends.
model.to(device)

# Define the loss function and optimizer
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Training Loop
num_epochs = 10  # Number of epochs
num_train_batches = len(train_dataloader)
for epoch in range(num_epochs):
    model.train()  # BatchNorm/Dropout in training mode for every epoch
    running_loss = 0.0
    for batch_idx, (images, targets) in enumerate(train_dataloader):
        images = images.to(device)
        # Targets come out of load_dataset as float64; cast on the CPU first
        # because MPS has no float64 support.
        targets = targets.float().to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        # Print progress every 10 batches (adjust as needed)
        if batch_idx % 10 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Batch [{batch_idx}/{num_train_batches}], Loss: {loss.item():.4f}')

    # Print the average loss for the epoch (guarded against an empty dataset)
    avg_loss = running_loss / max(num_train_batches, 1)
    print(f'Epoch [{epoch+1}/{num_epochs}], Average Loss: {avg_loss:.4f}')

# Evaluate on the held-out test set.
model.eval()
test_loss = 0.0
with torch.no_grad():
    for images, targets in test_dataloader:
        outputs = model(images.to(device))
        test_loss += criterion(outputs, targets.float().to(device)).item()
print(f'Test MSE: {test_loss / max(len(test_dataloader), 1):.4f}')

# Save the trained model if needed (tensors are saved on `device`; use
# map_location when loading on a different device).
torch.save(model.state_dict(), 'keypoint_model.pth')





Epoch [1/10], Batch [0/225], Loss: 10558.6113
