In [3]:
import random

import matplotlib.pyplot as plt
import numpy as np
import timm
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler

# Seed every RNG this notebook draws from. The dataset augmentations use the
# stdlib ``random`` module, which the torch/numpy seeds do not cover.
random.seed(42)
torch.manual_seed(42)
np.random.seed(42)

# Train on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class facemapdataset(Dataset):
    """Images paired with flattened ``(x0, y0, x1, y1, ...)`` keypoint labels.

    Samples whose label contains a NaN are dropped when the file is loaded.

    Augmentation is index-driven: ``index % len(self.data)`` selects the
    sample and ``index // len(self.data)`` selects one of ``NUM_VIEWS``
    variants (0 is the untouched original). Variants are only exposed by
    ``__len__`` and applied by ``__getitem__`` when ``transform`` is not None.

    NOTE: all views of one sample share the same base index, so a random split
    over ``range(len(dataset))`` with augmentation enabled can put views of the
    same image in different splits.

    Args:
        data_file: Path passed to ``torch.load``; must unpack to
            ``(data, targets)``.
        transform: Used purely as an on/off switch for augmentation; it is
            never called.
    """

    # The original sample plus augmentation types 1-9 handled in __getitem__.
    NUM_VIEWS = 10

    def __init__(self, data_file="data/dolensek_facemap_224.pt", transform=None):
        super().__init__()
        self.data, self.targets = torch.load(data_file)
        self.transform = transform

        # Filter out entries with NaN labels (done in torch so it does not
        # depend on NumPy being able to view the label tensors).
        targets = torch.as_tensor(self.targets, dtype=torch.float32)
        valid_indices = [
            i for i, label in enumerate(targets) if not torch.isnan(label).any()
        ]
        self.data = self.data[valid_indices]
        self.targets = targets[valid_indices]

    def __len__(self):
        """Return the number of addressable items.

        Without a transform this is the number of samples; with a transform
        every sample is exposed ``NUM_VIEWS`` times so that the augmented
        views in ``__getitem__`` are actually reachable.
        """
        if self.transform is None:
            return len(self.data)
        return len(self.data) * self.NUM_VIEWS

    def __getitem__(self, index):
        """Return ``(image, keypoints)`` for ``index`` as fresh clones."""
        n_samples = len(self.data)
        base_index = index % n_samples  # Map index to original data
        aug_type = index // n_samples   # Determine augmentation type

        image = self.data[base_index].clone()
        keypoints = self.targets[base_index].clone()

        if self.transform is None:
            return image, keypoints

        height, width = image.shape[-2:]
        center = (width / 2.0, height / 2.0)

        if aug_type == 1:  # Horizontal flip
            image = image.flip([2])
            # Mirror x-coordinates about the actual image width.
            keypoints[::2] = width - keypoints[::2]

        elif aug_type == 2:  # Random rotation
            angle = random.uniform(-30, 30)
            image = TF.rotate(image, angle)
            keypoints = self.rotate_keypoints(keypoints, angle, center)

        elif aug_type == 3:  # Zoom
            scale_factor = random.uniform(0.9, 1.1)
            image, keypoints = self.zoom(image, keypoints, scale_factor)

        elif aug_type == 4:  # Gaussian blur
            image = TF.gaussian_blur(image, kernel_size=3)

        elif aug_type == 5:  # Random cutout
            image = self.random_cutout(image)

        elif aug_type == 6:  # Adjust brightness
            image = TF.adjust_brightness(image, random.uniform(0.8, 1.2))

        elif aug_type == 7:  # Motion blur
            image = self.motion_blur(image)

        elif aug_type == 8:  # Random jitter
            image = self.random_jitter(image)

        elif aug_type == 9:  # Combined transformations
            angle = random.uniform(-15, 15)
            scale_factor = random.uniform(0.9, 1.1)
            image = TF.adjust_brightness(image, random.uniform(0.9, 1.1))
            image = self.random_jitter(image)
            image = TF.rotate(image, angle)
            keypoints = self.rotate_keypoints(keypoints, angle, center)
            image, keypoints = self.zoom(image, keypoints, scale_factor)

        return image, keypoints

    def rotate_keypoints(self, keypoints, angle, center=112.0):
        """Rotate flattened ``(x, y)`` keypoints by ``angle`` degrees.

        Points are treated as row vectors and multiplied by
        ``[[cos, -sin], [sin, cos]]``, so for a positive angle a point to the
        right of the centre moves towards smaller y.

        Args:
            keypoints: Tensor with an even number of elements.
            angle: Rotation angle in degrees.
            center: Rotation centre, either a scalar used for both axes or an
                ``(cx, cy)`` pair. Defaults to 112 (centre of a 224 px image).

        Returns:
            A new flattened tensor with the dtype of ``keypoints``.
        """
        radians = np.deg2rad(angle)
        cos_a, sin_a = float(np.cos(radians)), float(np.sin(radians))
        # Build the matrix in the keypoints' dtype/device so matmul never
        # fails on a float32/float64 mismatch.
        rotation_matrix = torch.tensor(
            [[cos_a, -sin_a], [sin_a, cos_a]],
            dtype=keypoints.dtype, device=keypoints.device,
        )
        center = torch.as_tensor(center, dtype=keypoints.dtype, device=keypoints.device)
        keypoints_xy = keypoints.reshape(-1, 2)
        keypoints_rotated = torch.matmul(keypoints_xy - center, rotation_matrix) + center
        return keypoints_rotated.reshape(-1)

    @staticmethod
    def _zoom_keypoints(keypoints, scale_x, scale_y, width, height):
        """Scale flattened ``(x, y)`` keypoints about the image centre."""
        center = keypoints.new_tensor([width / 2.0, height / 2.0])
        scale = keypoints.new_tensor([scale_x, scale_y])
        keypoints_xy = keypoints.reshape(-1, 2)
        return ((keypoints_xy - center) * scale + center).reshape(-1)

    def zoom(self, image, keypoints, scale_factor):
        """Resize ``image`` by ``scale_factor`` and centre-crop to its old size.

        The resize followed by a centre crop zooms about the image centre, so
        the keypoints are scaled about that same centre (not about the origin)
        and with the effective integer-rounded ratios. ``keypoints`` is not
        modified in place.

        Returns:
            ``(image, keypoints)`` after the zoom.
        """
        _, h, w = image.shape
        new_h, new_w = int(h * scale_factor), int(w * scale_factor)
        image = TF.resize(image, [new_h, new_w])
        image = TF.center_crop(image, [h, w])
        keypoints = self._zoom_keypoints(keypoints, new_w / w, new_h / h, w, h)
        return image, keypoints

    def random_cutout(self, image):
        """Zero a random box of 10-50 px per side in ``image`` (in place).

        The box is clamped to the image size so small images do not raise.

        Returns:
            The same ``image`` tensor that was passed in.
        """
        _, h, w = image.shape
        cutout_h = min(random.randint(10, 50), h)
        cutout_w = min(random.randint(10, 50), w)
        top = random.randint(0, h - cutout_h)
        left = random.randint(0, w - cutout_w)
        image[:, top:top + cutout_h, left:left + cutout_w] = 0
        return image

    def motion_blur(self, image, kernel_size=5, angle=45):
        """Blur ``image`` with a horizontal box kernel of ``kernel_size`` taps.

        Each channel is filtered independently (grouped convolution), so any
        channel count is supported.

        Args:
            image: Floating-point tensor indexed ``[channel, row, col]``.
            kernel_size: Kernel side length; odd values keep the image size.
            angle: Currently ignored; the blur is always horizontal. Kept so
                existing callers keep working.
        """
        channels = image.shape[0]
        kernel = torch.zeros((kernel_size, kernel_size),
                             dtype=image.dtype, device=image.device)
        kernel[kernel_size // 2, :] = 1.0 / kernel_size
        # One identical filter per channel.
        kernel = kernel[None, None].repeat(channels, 1, 1, 1)
        blurred = F.conv2d(image.unsqueeze(0), kernel,
                           padding=kernel_size // 2, groups=channels)
        return blurred.squeeze(0)

    def random_jitter(self, image, max_jitter=0.1):
        """Add Gaussian noise with std ``max_jitter`` and clamp to ``[0, 1]``."""
        noise = torch.randn_like(image) * max_jitter
        image = image + noise
        return torch.clamp(image, 0, 1)


import os  # only needed here, to create the output folders

# Make dataset
dataset = facemapdataset()  # samples with NaN labels are dropped by the dataset

x = dataset[0][0]
dim = x.shape[-1]
print('Using %d size of images' % dim)

N = len(dataset)

# Random, disjoint 60/20/20 split into train/valid/test indices
indices = np.random.permutation(N)
train_indices = indices[:int(0.6 * N)]
valid_indices = indices[int(0.6 * N):int(0.8 * N)]
test_indices = indices[int(0.8 * N):]

train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(valid_indices)
test_sampler = SubsetRandomSampler(test_indices)

batch_size = 4

# Initialize loss and metrics
loss_fun = torch.nn.MSELoss(reduction='sum')


def keypoint_loss(scores, labels):
    """Summed squared error in log space over keypoints with a non-zero label.

    Zero-valued labels are excluded from the loss. The same masking is used
    for train, validation and test so the reported numbers are comparable.
    """
    mask = labels != 0
    return loss_fun(torch.log(scores[mask]), torch.log(F.softplus(labels[mask])))


# Initialize input dimensions
num_train = len(train_sampler)
num_valid = len(valid_sampler)
num_test = len(test_sampler)
print("Num. train = %d, Num. val = %d, Num. test = %d" % (num_train, num_valid, num_test))

# Initialize dataloaders (pinned host memory is only useful with a GPU)
pin_memory = device.type == 'cuda'
loader_train = DataLoader(dataset=dataset, drop_last=False, num_workers=0,
                          batch_size=batch_size, pin_memory=pin_memory, sampler=train_sampler)
loader_valid = DataLoader(dataset=dataset, drop_last=True, num_workers=0,
                          batch_size=batch_size, pin_memory=pin_memory, sampler=valid_sampler)
loader_test = DataLoader(dataset=dataset, drop_last=True, num_workers=0,
                         batch_size=1, pin_memory=pin_memory, sampler=test_sampler)

nValid = len(loader_valid)
nTrain = len(loader_train)
nTest = len(loader_test)

### Hyperparameters
lr = 5e-4
num_epochs = 300
num_input_channels = 1  # Change this to the desired number of input channels
num_output_classes = 24  # Change this to the desired number of output classes

model = timm.create_model('vit_base_patch8_224',
                          pretrained=True, in_chans=num_input_channels,
                          num_classes=num_output_classes)

model = model.to(device)
nParam = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Number of parameters:%d M" % (nParam / 1e6))

optimizer = torch.optim.Adam(model.parameters(), lr=lr)
minLoss = 1e6
convIter = 0
convEpoch = None  # epoch of the best validation loss; None until one is saved
patience = 1000
train_loss = []
valid_loss = []

# savefig / torch.save do not create missing folders.
for out_dir in ('logs', 'models', 'preds'):
    os.makedirs(out_dir, exist_ok=True)

for epoch in range(num_epochs):
    model.train()
    tr_loss = 0
    for i, (inputs, labels) in enumerate(loader_train):
        inputs = inputs.to(device)
        labels = labels.to(device)
        scores = F.softplus(model(inputs))
        loss = keypoint_loss(scores, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
              .format(epoch + 1, num_epochs, i + 1, nTrain, loss.item()))
        tr_loss += loss.item()
    train_loss.append(tr_loss / max(nTrain, 1))

    model.eval()
    with torch.no_grad():
        val_loss = 0
        for inputs, labels in loader_valid:
            inputs = inputs.to(device)
            labels = labels.to(device)
            scores = F.softplus(model(inputs))
            val_loss += keypoint_loss(scores, labels).item()
        val_loss = val_loss / max(nValid, 1)

    valid_loss.append(val_loss)
    print('Val. loss :%.4f' % val_loss)

    # Plot predictions vs. labels for the last validation batch.
    if nValid > 0:
        img = inputs[:, 0].cpu().numpy()  # first (only) channel of each image
        pred = scores.cpu().numpy()
        labels = labels.cpu().numpy()
        plt.figure(figsize=(16, 6))
        for j in range(len(img)):
            plt.subplot(1, len(img), j + 1)
            plt.imshow(img[j], cmap='gray')
            plt.plot(pred[j, ::2], pred[j, 1::2], 'x', c='tab:red', label='pred.')
            plt.plot(labels[j, ::2], labels[j, 1::2], 'o', c='tab:green', label='label')
        plt.tight_layout()
        plt.savefig('logs/epoch_%03d.jpg' % epoch)
        plt.close()

    if minLoss > val_loss:
        convEpoch = epoch
        minLoss = val_loss
        convIter = 0
        torch.save(model.state_dict(), 'models/best_model.pt')
    else:
        convIter += 1

    if convIter >= patience:
        print('Converged at epoch %d with val. loss %.4f' % (convEpoch + 1, minLoss))
        break

plt.clf()
plt.plot(train_loss, label='Training')
plt.plot(valid_loss, label='Valid')
if convEpoch is not None:
    plt.plot(convEpoch, valid_loss[convEpoch], 'x', label='Final Model')
plt.legend()
plt.tight_layout()
plt.savefig('loss_curve.pdf')

### Load best model for inference
if convEpoch is not None:
    # Evaluate the checkpoint with the best validation loss, not the weights
    # left over from the last epoch.
    model.load_state_dict(torch.load('models/best_model.pt', map_location=device))
model.eval()

with torch.no_grad():
    val_loss = 0

    for i, (inputs, labels) in enumerate(loader_test):
        inputs = inputs.to(device)
        labels = labels.to(device)
        scores = F.softplus(model(inputs))
        loss = keypoint_loss(scores, labels)
        val_loss += loss.item()

        # Test batches hold one sample; take it and its first channel.
        img = inputs[0, 0].cpu().numpy()
        pred = scores[0].cpu().numpy()
        labels = labels[0].cpu().numpy()
        plt.figure()
        plt.imshow(img, cmap='gray')
        plt.plot(pred[::2], pred[1::2], 'x', c='tab:red')
        plt.plot(labels[::2], labels[1::2], 'o', c='tab:green')
        plt.tight_layout()
        plt.savefig('preds/test_%03d.jpg' % i)
        plt.close()

    val_loss = val_loss / max(nTest, 1)
    print('Test. loss :%.4f' % val_loss)


Using 224 size of images
Num. train = 165, Num. val = 55, Num. test = 55
Number of parameters:85 M
Epoch [1/300], Step [1/42], Loss: 2366.9194
Epoch [1/300], Step [2/42], Loss: 2034.3840
Epoch [1/300], Step [3/42], Loss: 1308.1278
Epoch [1/300], Step [4/42], Loss: 1051.6830
Epoch [1/300], Step [5/42], Loss: 759.5570
Epoch [1/300], Step [6/42], Loss: 717.6893
Epoch [1/300], Step [7/42], Loss: 553.8998
Epoch [1/300], Step [8/42], Loss: 561.1817
Epoch [1/300], Step [9/42], Loss: 480.5417
Epoch [1/300], Step [10/42], Loss: 466.3611
Epoch [1/300], Step [11/42], Loss: 421.8515
Epoch [1/300], Step [12/42], Loss: 460.6295
Epoch [1/300], Step [13/42], Loss: 349.2466
Epoch [1/300], Step [14/42], Loss: 419.2116
Epoch [1/300], Step [15/42], Loss: 295.9254
Epoch [1/300], Step [16/42], Loss: 340.0916
Epoch [1/300], Step [17/42], Loss: 374.8528
Epoch [1/300], Step [18/42], Loss: 316.1220
Epoch [1/300], Step [19/42], Loss: 340.6238
Epoch [1/300], Step [20/42], Loss: 346.9474
Epoch [1/300], Step [21/42

KeyboardInterrupt: 

<Figure size 640x480 with 0 Axes>

In [6]:
import random
import torchvision.transforms.functional as TF

import torch
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
import torch.nn.functional as F
import timm
import matplotlib.pyplot as plt
import numpy as np
import torchvision.transforms.functional as TF

# Seed every RNG this cell draws from. The dataset augmentations use the
# stdlib ``random`` module, which the torch/numpy seeds do not cover.
random.seed(42)
torch.manual_seed(42)
np.random.seed(42)

# Train on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class facemapdataset(Dataset):
    """Images paired with flattened ``(x0, y0, x1, y1, ...)`` keypoint labels.

    Samples whose label contains a NaN are dropped when the file is loaded.

    Augmentation is index-driven: ``index % len(self.data)`` selects the
    sample and ``index // len(self.data)`` selects one of ``NUM_VIEWS``
    variants (0 is the untouched original). ``__getitem__`` always honours the
    index; ``__len__`` only exposes the augmented indices when ``transform``
    is not None, so the default dataset keeps its original size.

    NOTE: all views of one sample share the same base index, so a random split
    over ``range(len(dataset))`` with augmentation enabled can put views of the
    same image in different splits.

    Args:
        data_file: Path passed to ``torch.load``; must unpack to
            ``(data, targets)``.
        transform: Used purely as an on/off switch for exposing augmented
            indices; it is never called.
    """

    # The original sample plus augmentation types 1-9 handled in __getitem__.
    NUM_VIEWS = 10

    def __init__(self, data_file="data/dolensek_facemap_224.pt", transform=None):
        super().__init__()
        self.data, self.targets = torch.load(data_file)
        self.transform = transform

        # Filter out entries with NaN labels (done in torch so it does not
        # depend on NumPy being able to view the label tensors).
        targets = torch.as_tensor(self.targets, dtype=torch.float32)
        valid_indices = [
            i for i, label in enumerate(targets) if not torch.isnan(label).any()
        ]
        self.data = self.data[valid_indices]
        self.targets = targets[valid_indices]

    def __len__(self):
        """Return the number of addressable items.

        Without a transform this is the number of samples; with a transform
        every sample is exposed ``NUM_VIEWS`` times so that the augmented
        views in ``__getitem__`` are actually reachable.
        """
        if self.transform is None:
            return len(self.data)
        return len(self.data) * self.NUM_VIEWS

    def __getitem__(self, index):
        """Return ``(image, keypoints)`` for ``index`` as fresh clones."""
        n_samples = len(self.data)
        base_index = index % n_samples  # Map index to original data
        aug_type = index // n_samples   # Determine augmentation type

        image = self.data[base_index].clone()
        keypoints = self.targets[base_index].clone()

        height, width = image.shape[-2:]
        center = (width / 2.0, height / 2.0)

        if aug_type == 1:  # Flipping
            image = image.flip([2])  # Horizontal flip
            # Mirror x-coordinates about the actual image width.
            keypoints[::2] = width - keypoints[::2]

        elif aug_type == 2:  # Rotation
            angle = random.uniform(-30, 30)
            image = TF.rotate(image, angle)
            keypoints = self.rotate_keypoints(keypoints, angle, center)

        elif aug_type == 3:  # Zoom
            scale_factor = random.uniform(0.8, 1.5)
            image, keypoints = self.zoom(image, keypoints, scale_factor)

        elif aug_type == 4:  # Gaussian Blur
            image = TF.gaussian_blur(image, kernel_size=3)

        elif aug_type == 5:  # Random cutout
            image = self.random_cutout(image)

        elif aug_type == 6:  # Adjust brightness
            image = TF.adjust_brightness(image, random.uniform(0.8, 1.2))

        elif aug_type == 7:  # Motion blur
            image = self.motion_blur(image)

        elif aug_type == 8:  # Random jitter
            image = self.random_jitter(image)

        elif aug_type == 9:  # Combined transformations
            angle = random.uniform(-15, 15)
            scale_factor = random.uniform(0.9, 1.1)
            image = TF.adjust_brightness(image, random.uniform(0.9, 1.1))
            image = self.random_jitter(image)
            image = TF.rotate(image, angle)
            keypoints = self.rotate_keypoints(keypoints, angle, center)
            image, keypoints = self.zoom(image, keypoints, scale_factor)

        return image, keypoints

    def rotate_keypoints(self, keypoints, angle, center=112.0):
        """Rotate flattened ``(x, y)`` keypoints by ``angle`` degrees.

        Points are treated as row vectors and multiplied by
        ``[[cos, -sin], [sin, cos]]``, so for a positive angle a point to the
        right of the centre moves towards smaller y.

        Args:
            keypoints: Tensor with an even number of elements.
            angle: Rotation angle in degrees.
            center: Rotation centre, either a scalar used for both axes or an
                ``(cx, cy)`` pair. Defaults to 112 (centre of a 224 px image).

        Returns:
            A new flattened tensor with the dtype of ``keypoints``.
        """
        radians = np.deg2rad(angle)
        cos_a, sin_a = float(np.cos(radians)), float(np.sin(radians))
        # Build the matrix in the keypoints' dtype/device so matmul never
        # fails on a float32/float64 mismatch.
        rotation_matrix = torch.tensor(
            [[cos_a, -sin_a], [sin_a, cos_a]],
            dtype=keypoints.dtype, device=keypoints.device,
        )
        center = torch.as_tensor(center, dtype=keypoints.dtype, device=keypoints.device)
        keypoints_xy = keypoints.reshape(-1, 2)  # Reshape to N x 2
        keypoints_rotated = torch.matmul(keypoints_xy - center, rotation_matrix) + center
        return keypoints_rotated.reshape(-1)

    @staticmethod
    def _zoom_keypoints(keypoints, scale_x, scale_y, width, height):
        """Scale flattened ``(x, y)`` keypoints about the image centre."""
        center = keypoints.new_tensor([width / 2.0, height / 2.0])
        scale = keypoints.new_tensor([scale_x, scale_y])
        keypoints_xy = keypoints.reshape(-1, 2)
        return ((keypoints_xy - center) * scale + center).reshape(-1)

    def zoom(self, image, keypoints, scale_factor):
        """Apply zoom to both image and keypoints.

        The image is resized by ``scale_factor`` and centre-cropped back to
        its old size, which zooms about the image centre. The keypoints are
        therefore scaled about that same centre (not about the origin), using
        the effective integer-rounded ratios. ``keypoints`` is not modified in
        place.

        Returns:
            ``(image, keypoints)`` after the zoom.
        """
        _, h, w = image.shape
        new_h, new_w = int(h * scale_factor), int(w * scale_factor)
        image = TF.resize(image, [new_h, new_w])
        image = TF.center_crop(image, [h, w])
        keypoints = self._zoom_keypoints(keypoints, new_w / w, new_h / h, w, h)
        return image, keypoints

    def random_cutout(self, image):
        """Apply random cutout to the image (in place).

        A box of 10-50 px per side is zeroed; it is clamped to the image size
        so small images do not raise.

        Returns:
            The same ``image`` tensor that was passed in.
        """
        _, h, w = image.shape
        cutout_h = min(random.randint(10, 50), h)
        cutout_w = min(random.randint(10, 50), w)
        top = random.randint(0, h - cutout_h)
        left = random.randint(0, w - cutout_w)
        image[:, top:top + cutout_h, left:left + cutout_w] = 0
        return image

    def motion_blur(self, image, kernel_size=5, angle=45):
        """Apply a horizontal box blur of ``kernel_size`` taps to the image.

        Each channel is filtered independently (grouped convolution), so any
        channel count is supported.

        Args:
            image: Floating-point tensor indexed ``[channel, row, col]``.
            kernel_size: Kernel side length; odd values keep the image size.
            angle: Currently ignored; the blur is always horizontal. Kept so
                existing callers keep working.
        """
        channels = image.shape[0]
        kernel = torch.zeros((kernel_size, kernel_size),
                             dtype=image.dtype, device=image.device)
        kernel[kernel_size // 2, :] = 1.0 / kernel_size
        # One identical filter per channel.
        kernel = kernel[None, None].repeat(channels, 1, 1, 1)
        blurred = F.conv2d(image.unsqueeze(0), kernel,
                           padding=kernel_size // 2, groups=channels)
        return blurred.squeeze(0)

    def random_jitter(self, image, max_jitter=0.1):
        """Add Gaussian noise with std ``max_jitter`` and clamp to ``[0, 1]``."""
        noise = torch.randn_like(image) * max_jitter
        image = image + noise
        return torch.clamp(image, 0, 1)



import os  # only needed here, to create the output folders

# Make dataset
dataset = facemapdataset()  # samples with NaN labels are dropped by the dataset

x = dataset[0][0]
dim = x.shape[-1]
print('Using %d size of images' % dim)

N = len(dataset)

# Random, disjoint 60/20/20 split into train/valid/test indices
indices = np.random.permutation(N)
train_indices = indices[:int(0.6 * N)]
valid_indices = indices[int(0.6 * N):int(0.8 * N)]
test_indices = indices[int(0.8 * N):]

train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(valid_indices)
test_sampler = SubsetRandomSampler(test_indices)

batch_size = 4

# Initialize loss and metrics
loss_fun = torch.nn.MSELoss(reduction='sum')


def keypoint_loss(scores, labels):
    """Summed squared error in log space over keypoints with a non-zero label.

    Zero-valued labels are excluded from the loss. The same masking is used
    for train, validation and test so the reported numbers are comparable.
    """
    mask = labels != 0
    return loss_fun(torch.log(scores[mask]), torch.log(F.softplus(labels[mask])))


# Initialize input dimensions
num_train = len(train_sampler)
num_valid = len(valid_sampler)
num_test = len(test_sampler)
print("Num. train = %d, Num. val = %d, Num. test = %d" % (num_train, num_valid, num_test))

# Initialize dataloaders (pinned host memory is only useful with a GPU)
pin_memory = device.type == 'cuda'
loader_train = DataLoader(dataset=dataset, drop_last=False, num_workers=0,
                          batch_size=batch_size, pin_memory=pin_memory, sampler=train_sampler)
loader_valid = DataLoader(dataset=dataset, drop_last=True, num_workers=0,
                          batch_size=batch_size, pin_memory=pin_memory, sampler=valid_sampler)
loader_test = DataLoader(dataset=dataset, drop_last=True, num_workers=0,
                         batch_size=1, pin_memory=pin_memory, sampler=test_sampler)

nValid = len(loader_valid)
nTrain = len(loader_train)
nTest = len(loader_test)

### Hyperparameters
lr = 5e-4
num_epochs = 300
num_input_channels = 1  # Change this to the desired number of input channels
num_output_classes = 24  # Change this to the desired number of output classes

model = timm.create_model('vit_base_patch8_224',
                          pretrained=True, in_chans=num_input_channels,
                          num_classes=num_output_classes)

model = model.to(device)
nParam = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Number of parameters:%d M" % (nParam / 1e6))

optimizer = torch.optim.Adam(model.parameters(), lr=lr)
minLoss = 1e6
convIter = 0
convEpoch = None  # epoch of the best validation loss; None until one is saved
patience = 1000
train_loss = []
valid_loss = []

# savefig / torch.save do not create missing folders.
for out_dir in ('logs', 'models', 'preds'):
    os.makedirs(out_dir, exist_ok=True)

for epoch in range(num_epochs):
    model.train()
    tr_loss = 0
    for i, (inputs, labels) in enumerate(loader_train):
        inputs = inputs.to(device)
        labels = labels.to(device)
        scores = F.softplus(model(inputs))
        loss = keypoint_loss(scores, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
              .format(epoch + 1, num_epochs, i + 1, nTrain, loss.item()))
        tr_loss += loss.item()
    train_loss.append(tr_loss / max(nTrain, 1))

    model.eval()
    with torch.no_grad():
        val_loss = 0
        for inputs, labels in loader_valid:
            inputs = inputs.to(device)
            labels = labels.to(device)
            scores = F.softplus(model(inputs))
            val_loss += keypoint_loss(scores, labels).item()
        val_loss = val_loss / max(nValid, 1)

    valid_loss.append(val_loss)
    print('Val. loss :%.4f' % val_loss)

    # Plot predictions vs. labels for the last validation batch.
    if nValid > 0:
        img = inputs[:, 0].cpu().numpy()  # first (only) channel of each image
        pred = scores.cpu().numpy()
        labels = labels.cpu().numpy()
        plt.figure(figsize=(16, 6))
        for j in range(len(img)):
            plt.subplot(1, len(img), j + 1)
            plt.imshow(img[j], cmap='gray')
            plt.plot(pred[j, ::2], pred[j, 1::2], 'x', c='tab:red', label='pred.')
            plt.plot(labels[j, ::2], labels[j, 1::2], 'o', c='tab:green', label='label')
        plt.tight_layout()
        plt.savefig('logs/epoch_%03d.jpg' % epoch)
        plt.close()

    if minLoss > val_loss:
        convEpoch = epoch
        minLoss = val_loss
        convIter = 0
        torch.save(model.state_dict(), 'models/best_model.pt')
    else:
        convIter += 1

    if convIter >= patience:
        print('Converged at epoch %d with val. loss %.4f' % (convEpoch + 1, minLoss))
        break

plt.clf()
plt.plot(train_loss, label='Training')
plt.plot(valid_loss, label='Valid')
if convEpoch is not None:
    plt.plot(convEpoch, valid_loss[convEpoch], 'x', label='Final Model')
plt.legend()
plt.tight_layout()
plt.savefig('loss_curve.pdf')

### Load best model for inference
if convEpoch is not None:
    # Evaluate the checkpoint with the best validation loss, not the weights
    # left over from the last epoch.
    model.load_state_dict(torch.load('models/best_model.pt', map_location=device))
model.eval()

with torch.no_grad():
    val_loss = 0

    for i, (inputs, labels) in enumerate(loader_test):
        inputs = inputs.to(device)
        labels = labels.to(device)
        scores = F.softplus(model(inputs))
        loss = keypoint_loss(scores, labels)
        val_loss += loss.item()

        # Test batches hold one sample; take it and its first channel.
        img = inputs[0, 0].cpu().numpy()
        pred = scores[0].cpu().numpy()
        labels = labels[0].cpu().numpy()
        plt.figure()
        plt.imshow(img, cmap='gray')
        plt.plot(pred[::2], pred[1::2], 'x', c='tab:red')
        plt.plot(labels[::2], labels[1::2], 'o', c='tab:green')
        plt.tight_layout()
        plt.savefig('preds/test_%03d.jpg' % i)
        plt.close()

    val_loss = val_loss / max(nTest, 1)
    print('Test. loss :%.4f' % val_loss)

Using 224 size of images
Num. train = 165, Num. val = 55, Num. test = 55
Number of parameters:85 M
Epoch [1/300], Step [1/42], Loss: 2366.9194
Epoch [1/300], Step [2/42], Loss: 2034.3782
Epoch [1/300], Step [3/42], Loss: 1308.1287
Epoch [1/300], Step [4/42], Loss: 1051.7107
Epoch [1/300], Step [5/42], Loss: 759.5363
Epoch [1/300], Step [6/42], Loss: 717.7383
Epoch [1/300], Step [7/42], Loss: 553.8943
Epoch [1/300], Step [8/42], Loss: 561.1935
Epoch [1/300], Step [9/42], Loss: 480.5566
Epoch [1/300], Step [10/42], Loss: 466.3644
Epoch [1/300], Step [11/42], Loss: 421.8497
Epoch [1/300], Step [12/42], Loss: 460.6293
Epoch [1/300], Step [13/42], Loss: 349.2500
Epoch [1/300], Step [14/42], Loss: 419.2104
Epoch [1/300], Step [15/42], Loss: 295.9288
Epoch [1/300], Step [16/42], Loss: 340.0952
Epoch [1/300], Step [17/42], Loss: 374.8594
Epoch [1/300], Step [18/42], Loss: 316.1287
Epoch [1/300], Step [19/42], Loss: 340.6282
Epoch [1/300], Step [20/42], Loss: 346.9501
Epoch [1/300], Step [21/42

KeyboardInterrupt: 

<Figure size 640x480 with 0 Axes>

In [None]:
import pdb
import glob
import pandas as pd
from numpy import random
import numpy as np
import timm
import matplotlib.pyplot as plt

# Fix the NumPy and PyTorch RNG states so splits and initialisation repeat.
np.random.seed(42)
torch.manual_seed(42)


def random_cutout(image, label, max_h=50, max_w=50):
    """
    Apply the same random cutout box to both image and label (in place).

    Both tensors are indexed as ``[channel, row, col]`` and the box position
    is drawn from the image's height and width, so ``label`` must be at least
    as large spatially. The box is at least 10 px per side where the image and
    the ``max_*`` limits allow it, and ``max_h`` / ``max_w`` are inclusive.
    Sizes are clamped to the image so small images or small limits do not
    raise.

    Args:
        image: The input image tensor.
        label: The label tensor.
        max_h: Maximum height of the cutout box.
        max_w: Maximum width of the cutout box.
    Returns:
        image: Image after cutout (the same tensor object).
        label: Label after cutout (the same tensor object).
    """
    _, h, w = image.shape
    max_h = max(1, min(max_h, h))
    max_w = max(1, min(max_w, w))

    # np.random.randint excludes its upper bound, hence the "+ 1" below.
    cutout_height = np.random.randint(min(10, max_h), max_h + 1)
    cutout_width = np.random.randint(min(10, max_w), max_w + 1)

    # Randomly choose the position for the cutout
    top = np.random.randint(0, h - cutout_height + 1)
    left = np.random.randint(0, w - cutout_width + 1)

    # Apply the cutout to the image and label (set to 0)
    image[:, top:top + cutout_height, left:left + cutout_width] = 0
    label[:, top:top + cutout_height, left:left + cutout_width] = 0

    return image, label

# Run on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

class facemapdataset(Dataset):
    """Images paired with flattened ``(x0, y0, x1, y1, ...)`` keypoint labels.

    Samples whose label contains a NaN are dropped when the file is loaded.

    Augmentation is index-driven: ``index % len(self.data)`` selects the
    sample and ``index // len(self.data)`` selects the view (0 = original,
    1 = horizontal flip). Augmentation is only exposed by ``__len__`` and
    applied by ``__getitem__`` when ``transform`` is not None.

    Args:
        data_file: Path passed to ``torch.load``; must unpack to
            ``(data, targets)``.
        transform: Used purely as an on/off switch for augmentation; it is
            never called.
        cutout_prob: Probability of zeroing a random box in the image when
            augmentation is enabled. NOTE(review): the original code read
            ``self.cutout_prob`` without ever defining it; 0.5 is a
            placeholder default, so confirm the intended value.

    Raises:
        ValueError: If ``cutout_prob`` is outside ``[0, 1]``.
    """

    # The original sample plus the horizontally flipped view.
    NUM_VIEWS = 2

    def __init__(self, data_file="data/dolensek_facemap_224.pt", transform=None,
                 cutout_prob=0.5):
        super().__init__()
        if not 0.0 <= cutout_prob <= 1.0:
            raise ValueError("cutout_prob must be in [0, 1], got %r" % (cutout_prob,))

        self.transform = transform
        self.cutout_prob = cutout_prob
        self.data, self.targets = torch.load(data_file)

        # Filter out entries where any label value is NaN. No NaN is left
        # afterwards, so no further replacement is needed.
        targets = torch.as_tensor(self.targets, dtype=torch.float32)
        valid_indices = [
            i for i, label in enumerate(targets) if not torch.isnan(label).any()
        ]

        # Keep only the valid entries
        self.data = self.data[valid_indices]
        self.targets = targets[valid_indices]

    def __len__(self):
        """Return the number of addressable items.

        Without a transform this is the number of samples; with a transform
        every sample is exposed ``NUM_VIEWS`` times so the flipped view is
        actually reachable.
        """
        if self.transform is None:
            return len(self.targets)
        return len(self.targets) * self.NUM_VIEWS

    @staticmethod
    def _cutout_image(image, max_h=50, max_w=50):
        """Zero a random box (10 px up to ``max_*`` per side) in ``image``.

        Only the image is touched: the labels of this dataset are keypoint
        vectors, not masks, so there is nothing spatial to blank in them.
        """
        _, h, w = image.shape
        max_h = max(1, min(max_h, h))
        max_w = max(1, min(max_w, w))
        # np.random.randint excludes its upper bound, hence the "+ 1".
        cutout_h = np.random.randint(min(10, max_h), max_h + 1)
        cutout_w = np.random.randint(min(10, max_w), max_w + 1)
        top = np.random.randint(0, h - cutout_h + 1)
        left = np.random.randint(0, w - cutout_w + 1)
        image[:, top:top + cutout_h, left:left + cutout_w] = 0
        return image

    def __getitem__(self, index):
        """Return ``(image, label)`` for ``index`` as fresh clones."""
        n_samples = len(self.data)
        base_index = index % n_samples  # This will prevent out-of-bounds errors
        aug_type = index // n_samples   # This will determine which augmentation to apply

        # Load the original image and label
        image, label = self.data[base_index].clone(), self.targets[base_index].clone()

        if self.transform is not None:
            if aug_type == 1:  # Flipping
                image = image.flip([2])  # Horizontal flip
                # Mirror x-coordinates about the actual image width.
                label[::2] = image.shape[-1] - label[::2]

            # Apply random cutout with probability cutout_prob
            if np.random.random() < self.cutout_prob:
                image = self._cutout_image(image)
        return image, label

import os  # only needed here, to create the output folders

# Make dataset
dataset = facemapdataset()  # samples with NaN labels are dropped by the dataset

x = dataset[0][0]
dim = x.shape[-1]
print('Using %d size of images' % dim)

N = len(dataset)

# Random, disjoint 60/20/20 split into train/valid/test indices
# (replaces the earlier contiguous-range split).
indices = np.random.permutation(N)
train_indices = indices[:int(0.6 * N)]
valid_indices = indices[int(0.6 * N):int(0.8 * N)]
test_indices = indices[int(0.8 * N):]

train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(valid_indices)
test_sampler = SubsetRandomSampler(test_indices)

batch_size = 4

# Initialize loss and metrics
loss_fun = torch.nn.MSELoss(reduction='sum')


def keypoint_loss(scores, labels):
    """Summed squared error in log space over keypoints with a non-zero label.

    Zero-valued labels are excluded from the loss. The same masking is used
    for train, validation and test so the reported numbers are comparable.
    """
    mask = labels != 0
    return loss_fun(torch.log(scores[mask]), torch.log(F.softplus(labels[mask])))


# Initialize input dimensions
num_train = len(train_sampler)
num_valid = len(valid_sampler)
num_test = len(test_sampler)
print("Num. train = %d, Num. val = %d, Num. test = %d" % (num_train, num_valid, num_test))

# Initialize dataloaders (pinned host memory is only useful with a GPU)
pin_memory = device.type == 'cuda'
loader_train = DataLoader(dataset=dataset, drop_last=False, num_workers=0,
                          batch_size=batch_size, pin_memory=pin_memory, sampler=train_sampler)
loader_valid = DataLoader(dataset=dataset, drop_last=True, num_workers=0,
                          batch_size=batch_size, pin_memory=pin_memory, sampler=valid_sampler)
loader_test = DataLoader(dataset=dataset, drop_last=True, num_workers=0,
                         batch_size=1, pin_memory=pin_memory, sampler=test_sampler)

nValid = len(loader_valid)
nTrain = len(loader_train)
nTest = len(loader_test)

### Hyperparameters
lr = 5e-4
num_epochs = 300
num_input_channels = 1  # Change this to the desired number of input channels
num_output_classes = 24  # Change this to the desired number of output classes

model = timm.create_model('vit_base_patch8_224',
                          pretrained=True, in_chans=num_input_channels,
                          num_classes=num_output_classes)

model = model.to(device)
nParam = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Number of parameters:%d M" % (nParam / 1e6))

optimizer = torch.optim.Adam(model.parameters(), lr=lr)
minLoss = 1e6
convIter = 0
convEpoch = None  # epoch of the best validation loss; None until one is saved
patience = 1000
train_loss = []
valid_loss = []

# savefig / torch.save do not create missing folders.
for out_dir in ('logs', 'models', 'preds'):
    os.makedirs(out_dir, exist_ok=True)

for epoch in range(num_epochs):
    model.train()
    tr_loss = 0
    for i, (inputs, labels) in enumerate(loader_train):
        inputs = inputs.to(device)
        labels = labels.to(device)
        scores = F.softplus(model(inputs))
        loss = keypoint_loss(scores, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
              .format(epoch + 1, num_epochs, i + 1, nTrain, loss.item()))
        tr_loss += loss.item()
    train_loss.append(tr_loss / max(nTrain, 1))

    model.eval()
    with torch.no_grad():
        val_loss = 0
        for inputs, labels in loader_valid:
            inputs = inputs.to(device)
            labels = labels.to(device)
            scores = F.softplus(model(inputs))
            val_loss += keypoint_loss(scores, labels).item()
        val_loss = val_loss / max(nValid, 1)

    valid_loss.append(val_loss)
    print('Val. loss :%.4f' % val_loss)

    # Plot predictions vs. labels for the last validation batch.
    if nValid > 0:
        img = inputs[:, 0].cpu().numpy()  # first (only) channel of each image
        pred = scores.cpu().numpy()
        labels = labels.cpu().numpy()
        plt.figure(figsize=(16, 6))
        for j in range(len(img)):
            plt.subplot(1, len(img), j + 1)
            plt.imshow(img[j], cmap='gray')
            plt.plot(pred[j, ::2], pred[j, 1::2], 'x', c='tab:red', label='pred.')
            plt.plot(labels[j, ::2], labels[j, 1::2], 'o', c='tab:green', label='label')
        plt.tight_layout()
        plt.savefig('logs/epoch_%03d.jpg' % epoch)
        plt.close()

    if minLoss > val_loss:
        convEpoch = epoch
        minLoss = val_loss
        convIter = 0
        torch.save(model.state_dict(), 'models/best_model.pt')
    else:
        convIter += 1

    if convIter >= patience:
        print('Converged at epoch %d with val. loss %.4f' % (convEpoch + 1, minLoss))
        break

plt.clf()
plt.plot(train_loss, label='Training')
plt.plot(valid_loss, label='Valid')
if convEpoch is not None:
    plt.plot(convEpoch, valid_loss[convEpoch], 'x', label='Final Model')
plt.legend()
plt.tight_layout()
plt.savefig('loss_curve.pdf')

### Load best model for inference
if convEpoch is not None:
    # Evaluate the checkpoint with the best validation loss, not the weights
    # left over from the last epoch.
    model.load_state_dict(torch.load('models/best_model.pt', map_location=device))
model.eval()

with torch.no_grad():
    val_loss = 0

    for i, (inputs, labels) in enumerate(loader_test):
        inputs = inputs.to(device)
        labels = labels.to(device)
        scores = F.softplus(model(inputs))
        loss = keypoint_loss(scores, labels)
        val_loss += loss.item()

        # Test batches hold one sample; take it and its first channel.
        img = inputs[0, 0].cpu().numpy()
        pred = scores[0].cpu().numpy()
        labels = labels[0].cpu().numpy()
        plt.figure()
        plt.imshow(img, cmap='gray')
        plt.plot(pred[::2], pred[1::2], 'x', c='tab:red')
        plt.plot(labels[::2], labels[1::2], 'o', c='tab:green')
        plt.tight_layout()
        plt.savefig('preds/test_%03d.jpg' % i)
        plt.close()

    val_loss = val_loss / max(nTest, 1)
    print('Test. loss :%.4f' % val_loss)


Using 224 size of images
Num. train = 165, Num. val = 55, Num. test = 55
Number of parameters:85 M
Epoch [1/300], Step [1/42], Loss: 2366.9194
Epoch [1/300], Step [2/42], Loss: 2034.3792
Epoch [1/300], Step [3/42], Loss: 1308.1312
Epoch [1/300], Step [4/42], Loss: 1051.7091
Epoch [1/300], Step [5/42], Loss: 759.5433
Epoch [1/300], Step [6/42], Loss: 717.7389
Epoch [1/300], Step [7/42], Loss: 553.8943
Epoch [1/300], Step [8/42], Loss: 561.1953
Epoch [1/300], Step [9/42], Loss: 480.5583
Epoch [1/300], Step [10/42], Loss: 466.3651
Epoch [1/300], Step [11/42], Loss: 421.8498
Epoch [1/300], Step [12/42], Loss: 460.6294
Epoch [1/300], Step [13/42], Loss: 349.2504
Epoch [1/300], Step [14/42], Loss: 419.2108
Epoch [1/300], Step [15/42], Loss: 295.9291
Epoch [1/300], Step [16/42], Loss: 340.0952
Epoch [1/300], Step [17/42], Loss: 374.8594
Epoch [1/300], Step [18/42], Loss: 316.1288
Epoch [1/300], Step [19/42], Loss: 340.6282
Epoch [1/300], Step [20/42], Loss: 346.9500
Epoch [1/300], Step [21/42