In [None]:
# Install the packages listed in requirements.txt; -q keeps pip's output quiet.
%pip install -q -r requirements.txt

In [None]:
import torch
import os
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import pandas as pd
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import cv2

# Use the MPS backend when PyTorch reports it as available; otherwise stay on the CPU.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print("Using device:", device)

In [None]:
class MSTEDataset(Dataset):
    """Image dataset laid out as ``root_dir/<person_id>/<image file>``.

    Every immediate sub-directory of ``root_dir`` is treated as one person,
    except directories whose name starts with ``"cheek"``, which are skipped.
    Each item is an ``(image, person_id)`` pair, where ``person_id`` is the
    sub-directory name (a string, not an encoded integer label).

    Args:
        root_dir: Directory containing one sub-directory per person.
        transform: Optional callable applied to the RGB ``PIL.Image`` before it
            is returned. When ``None`` the PIL image is returned unchanged.

    Attributes:
        samples: List of ``(image_path, person_id)`` tuples, ordered by sorted
            person id and then by sorted file name.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.samples = []

        # Walk the directory and gather (image_path, person_id).
        for person_id in sorted(os.listdir(root_dir)):
            person_folder = os.path.join(root_dir, person_id)
            if not os.path.isdir(person_folder) or person_id.startswith("cheek"):
                continue
            for img_name in sorted(os.listdir(person_folder)):
                img_path = os.path.join(person_folder, img_name)
                if self._is_image_entry(img_name, img_path):
                    self.samples.append((img_path, person_id))

    @staticmethod
    def _is_image_entry(name, path):
        """Return True when a directory entry should be loaded as an image."""
        # Hidden files (e.g. macOS ".DS_Store") are not images and would make
        # Image.open fail later, in the middle of an epoch.
        if name.startswith("."):
            return False
        # Video clips live next to the stills; match the extension regardless of case.
        if name.lower().endswith(".mp4"):
            return False
        # Nested directories cannot be opened as images either.
        return os.path.isfile(path)

    def __len__(self):
        """Return the number of collected image samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, person_id)``."""
        img_path, person_id = self.samples[idx]
        image = Image.open(img_path).convert("RGB")
        # transform defaults to None, so it must not be called unconditionally.
        if self.transform is not None:
            image = self.transform(image)
        return image, person_id  # person_id is the raw folder name; encode it if needed

In [None]:
# Side length (pixels) every image is resized to, and the loader's batch size.
IMAGE_SIZE = 224
BATCH_SIZE = 32

# Dataset location. Defaults to the original local path; set the MSTE_ROOT
# environment variable to run the notebook elsewhere without editing this cell.
MSTE_ROOT = os.environ.get("MSTE_ROOT", "/Users/sree/mst-e")

# Resize each image to a fixed square and convert it to a tensor.
transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    # Normalization is currently disabled:
    # transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

mste_e = MSTEDataset(root_dir=MSTE_ROOT, transform=transform)
mste_e_loader = DataLoader(mste_e, batch_size=BATCH_SIZE, shuffle=True)

In [None]:
# # Display all images in the mste_e dataset (left disabled: it opens one figure per sample)
# for idx, (image, label) in enumerate(mste_e):
#     plt.figure(figsize=(3, 3))
#     plt.imshow(image.permute(1, 2, 0).numpy())  # Convert tensor to numpy array and adjust dimensions
#     plt.title(f"Label: {label}")
#     plt.axis("off")
#     plt.show()