I think a Swin Transformer might be a good fit, since it seems to have versatile use cases. I think what I am doing is semantic segmentation, but instead of binary pixel values of 0 and 1, each pixel takes a value between 0 and 1, forming a heatmap of keypoint (KP) locations.

https://github.com/microsoft/Swin-Transformer

In [3]:
import torch
from torch import nn
from torch.optim import Adam
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
import timm
import os


This dataset class loads grayscale images and heatmap labels, converting the grayscale images to 3 channels (to match the Swin Transformer’s expectations). I'm not sure this is a great way to do it; I should figure out how to keep the original RGB structure from before Raghav's dataset creation step.

In [4]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

# Define Dataset class for loading facemap softlabels from a .pt file
class FaceMapDataset(Dataset):
    """Dataset of single-channel face images paired with soft-label heatmaps.

    The samples come from one ``.pt`` file whose content unpacks into three
    items: the first is kept as the images, the third as the heatmap
    targets, and the middle one is ignored.

    Args:
        data_file: Path of the ``.pt`` file passed to ``torch.load``.
        transform: Optional callable applied to the image and then to its
            label. Random transforms that draw from torch's global CPU
            generator make the same random choice for both, so an
            augmentation such as a random flip cannot hit the image and
            miss its heatmap.
    """

    def __init__(self, data_file="data/facemap_softlabels.pt", transform=None):
        super().__init__()
        self.transform = transform
        self.data, _, self.targets = torch.load(data_file)

    def __len__(self):
        """Return the number of samples (one per stored target)."""
        return len(self.targets)

    def __getitem__(self, index):
        """Return the ``(image, label)`` pair stored at ``index``.

        The stored tensors are cloned so the cached data is never modified.
        The image is tiled three times along its first dimension.
        """
        image, label = self.data[index].clone(), self.targets[index].clone()

        # Convert grayscale to 3-channel by repeating the single channel
        # (assumes a (1, H, W) image, e.g. (1, 224, 224) -> (3, 224, 224)).
        image = image.repeat(3, 1, 1)

        # Apply transformations if provided (e.g., flipping for augmentation)
        if self.transform is not None:
            # Replay the same RNG state for the label so that a random
            # transform takes the same decision for the image and its heatmap.
            # NOTE(review): this only synchronises transforms that use
            # torch's default CPU generator.
            rng_state = torch.get_rng_state()
            image = self.transform(image)
            torch.set_rng_state(rng_state)
            label = self.transform(label)

        return image, label


# Preprocessing handed to the dataset: resize to 224x224.
transform = transforms.Compose([transforms.Resize(size=(224, 224))])

# Build the dataset from the saved soft-label file.
dataset = FaceMapDataset("data/facemap_softlabels.pt", transform)

# Serve shuffled mini-batches of 8 samples for training.
train_loader = DataLoader(dataset=dataset, shuffle=True, batch_size=8)


Trying a Swin model with a modified output layer so that, hopefully, it produces the required output (a 224x224 heatmap of likely keypoint placement).

In [14]:

import torch
import torch.nn as nn
import timm

class SimpleSwinHeatmap(nn.Module):
    def __init__(self, pretrained=True):
        super(SimpleSwinHeatmap, self).__init__()
        
        # Load the Swin Transformer model as a feature extractor
        self.encoder = timm.create_model('swin_base_patch4_window7_224', pretrained=pretrained, features_only=True, out_indices=(3,))
        
        # Final convolutional layer to reduce to single-channel heatmap output
        self.conv_out = nn.Conv2d(in_channels=1024, out_channels=1, kernel_size=1)  # Corrected input channels to 1024

    def forward(self, x):
        # Extract features from the last stage of Swin
        x = self.encoder(x)[0]  # Access the last feature map with appropriate dimensions
        
        # Apply the final convolution to produce a single-channel heatmap
        x = self.conv_out(x)
        
        # Resize to match the target output size if necessary
        x = nn.functional.interpolate(x, size=(224, 224), mode='bilinear', align_corners=False)
        return x

# Pick the GPU when one is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Build the model and move it onto the selected device
model = SimpleSwinHeatmap()
model = model.to(device)

# Mean-squared-error loss, optimised with Adam at a learning rate of 1e-4
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)


In [15]:
# The loss and optimizer are not recreated here:
# criterion = nn.MSELoss()
# optimizer = Adam(model.parameters(), lr=1e-4)

# Re-select the device (GPU when available) and make sure the model is on it
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)


SimpleSwinHeatmap(
  (encoder): FeatureListNet(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 128, kernel_size=(4, 4), stride=(4, 4))
      (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    )
    (layers_0): SwinTransformerStage(
      (downsample): Identity()
      (blocks): Sequential(
        (0): SwinTransformerBlock(
          (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (attn): WindowAttention(
            (qkv): Linear(in_features=128, out_features=384, bias=True)
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=128, out_features=128, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
            (softmax): Softmax(dim=-1)
          )
          (drop_path1): Identity()
          (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=128, out_features=512, bias=True)
            (act): GELU(approximate

In [16]:
# Train for a fixed number of epochs, reporting the mean batch loss per epoch
num_epochs = 10

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    for images, masks in train_loader:
        # Move the batch onto the same device as the model
        images = images.to(device)
        masks = masks.to(device)

        # Forward pass, starting from cleared gradients
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)

        # Backward pass and parameter update
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss / len(train_loader)}")

# Persist the trained weights
torch.save(model.state_dict(), 'simple_swin_heatmap.pth')


RuntimeError: Given groups=1, weight of size [1, 1024, 1, 1], expected input[8, 7, 7, 1024] to have 1024 channels, but got 7 channels instead

In [10]:
# # Training loop
# num_epochs = 10  # Set number of epochs

# for epoch in range(num_epochs):
#     model.train()
#     running_loss = 0.0

#     for images, masks in train_loader:
#         images, masks = images.to(device), masks.to(device)

#         # Forward pass
#         optimizer.zero_grad()
#         outputs = model(images)

#         # Resize outputs if necessary to match mask dimensions
#         outputs = nn.functional.interpolate(outputs, size=(224, 224), mode='bilinear', align_corners=False)

#         # Compute loss and backpropagate
#         loss = criterion(outputs, masks)
#         loss.backward()
#         optimizer.step()

#         running_loss += loss.item()

#     print(f"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss / len(train_loader)}")

# # Save the model after training
# # torch.save(model.state_dict(), 'fine_tuned_swin_segmentation.pth')

# num_epochs = 10

# for epoch in range(num_epochs):
#     model.train()
#     running_loss = 0.0
#     for images, masks in train_loader:
#         images, masks = images.to(device), masks.to(device)

#         optimizer.zero_grad()
#         outputs = model(images)
        
#         # Resize outputs to match mask dimensions if necessary
#         outputs = nn.functional.interpolate(outputs, size=(224, 224), mode='bilinear', align_corners=False)

#         # Compute loss and backpropagate
#         loss = criterion(outputs, masks)
#         loss.backward()
#         optimizer.step()

#         running_loss += loss.item()

#     print(f"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss / len(train_loader)}")

# # Save the model
# torch.save(model.state_dict(), 'simple_swin_heatmap.pth')


RuntimeError: Given groups=1, weight of size [1, 1024, 1, 1], expected input[8, 7, 7, 1024] to have 1024 channels, but got 7 channels instead

Trying ChatGPT's suggestion: a U-Net and Swin lovechild.



In [7]:
import torch
import torch.nn as nn
import timm

class SwinUNet(nn.Module):
    def __init__(self, pretrained=True):
        super(SwinUNet, self).__init__()
        
        # Load Swin Transformer with multi-stage feature extraction
        self.encoder = timm.create_model('swin_base_patch4_window7_224', pretrained=pretrained, features_only=True, out_indices=(0, 1, 2, 3))
        
        # Decoder part with transposed convolutions for upsampling
        self.upconv3 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.upconv2 = nn.ConvTranspose2d(512 + 512, 256, kernel_size=2, stride=2)
        self.upconv1 = nn.ConvTranspose2d(256 + 256, 128, kernel_size=2, stride=2)
        self.upconv0 = nn.ConvTranspose2d(128 + 128, 64, kernel_size=2, stride=2)
        
        # Final output layer
        self.final_conv = nn.Conv2d(64, 1, kernel_size=1)
        
        # Activation
        self.act = nn.ReLU()

    def forward(self, x):
        # Encoder: extract features from different stages
        enc_features = self.encoder(x)
        
        # Decoder with skip connections
        d3 = self.act(self.upconv3(enc_features[3]))               # Upsample last encoder layer
        d3 = torch.cat((d3, enc_features[2]), dim=1)               # Concatenate with encoder stage 3
        
        d2 = self.act(self.upconv2(d3))                            # Upsample
        d2 = torch.cat((d2, enc_features[1]), dim=1)               # Concatenate with encoder stage 2
        
        d1 = self.act(self.upconv1(d2))                            # Upsample
        d1 = torch.cat((d1, enc_features[0]), dim=1)               # Concatenate with encoder stage 1
        
        d0 = self.act(self.upconv0(d1))                            # Final upsample
        
        # Final layer to produce single-channel output
        out = self.final_conv(d0)
        
        return out

# Pick the GPU when one is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Build the Swin U-Net and move it onto the selected device
model = SwinUNet()
model = model.to(device)

# Mean-squared-error loss, optimised with Adam at a learning rate of 1e-4
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)



In [8]:
# Train for a fixed number of epochs, reporting the mean batch loss per epoch
num_epochs = 10

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    for images, masks in train_loader:
        # Move both halves of the batch onto the model's device
        images, masks = (batch.to(device) for batch in (images, masks))

        # Forward pass, starting from cleared gradients
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)

        # Backward pass and parameter update
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss / len(train_loader)}")

# Persist the trained weights
torch.save(model.state_dict(), 'swin_unet_heatmap.pth')


RuntimeError: Given transposed=1, weight of size [1024, 512, 2, 2], expected input[8, 7, 7, 1024] to have 1024 channels, but got 7 channels instead