In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

### Creating a dataset that resizes the image and mask together
Using relis-2d

In [ ]:
import torch
from torch.utils.data import Dataset
from torchvision import datasets, transforms
from PIL import Image

# Default target size shared by images and masks.
RESIZE = (224, 224)

def resize_image_and_mask(image, mask, size=None):
    """Resize an image and its segmentation mask to the same target size.

    The image is resized with bilinear interpolation; the mask is resized with
    nearest-neighbour interpolation so that no new (blended) label values are
    introduced.

    Args:
        image: Image to resize; forwarded unchanged to ``transforms.Resize``.
        mask: Segmentation mask paired with ``image``.
        size: Optional target size forwarded to ``transforms.Resize``. When
            omitted, the module-level ``RESIZE`` constant is used.

    Returns:
        A ``(resized_image, resized_mask)`` tuple.
    """
    target_size = RESIZE if size is None else size
    # Pass InterpolationMode members instead of the legacy PIL integer
    # constants: the enum is the type torchvision documents for this argument.
    image_resizer = transforms.Resize(
        target_size, interpolation=transforms.InterpolationMode.BILINEAR
    )
    # NEAREST keeps mask values restricted to the labels already present.
    mask_resizer = transforms.Resize(
        target_size, interpolation=transforms.InterpolationMode.NEAREST
    )
    return image_resizer(image), mask_resizer(mask)

class CustomDataset(Dataset):
    """Oxford-IIIT Pet segmentation pairs, resized together before any transform.

    Wraps ``datasets.OxfordIIITPet`` (requested with ``download=True`` and
    ``target_types='segmentation'``). Every sample is first passed through
    ``resize_image_and_mask`` and then through the optional per-item callables.

    Args:
        root: Directory handed to ``datasets.OxfordIIITPet`` as ``root``.
        transform: Optional callable applied to the resized image.
        target_transform: Optional callable applied to the resized mask.
    """

    def __init__(self, root, transform=None, target_transform=None):
        self.transform = transform
        self.target_transform = target_transform
        self.dataset = datasets.OxfordIIITPet(root=root, download=True, target_types='segmentation')

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the ``(image, mask)`` pair at ``idx`` after resizing and transforms."""
        raw_image, raw_mask = self.dataset[idx]
        # Resize both items in one call so they end up with matching sizes.
        sized_image, sized_mask = resize_image_and_mask(raw_image, raw_mask)
        out_image = self.transform(sized_image) if self.transform else sized_image
        out_mask = self.target_transform(sized_mask) if self.target_transform else sized_mask
        return out_image, out_mask
        

In [ ]:
import os

# Directory where the Oxford-IIIT Pet data is downloaded to / read from.
SIMPLE_DATA_FILE_PATH = "~/Projects/WildFusionPlus/input/2d_model_dummy/"


# Image pipeline: resize -> float tensor -> per-channel normalisation.
# Uses RESIZE (defined with the resize helper above); the previously referenced
# IMAGE_RESIZE is not defined anywhere in this notebook and raised NameError.
rgb_image_transform = transforms.Compose([
    transforms.Resize(RESIZE),  # Resize to a fixed size
    transforms.ToTensor(),      # Convert image to PyTorch Tensor
    # Mean/std are the ImageNet channel statistics commonly used with torchvision models.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Mask pipeline: resize to the same size as the image, then convert to a tensor.
mask_transform = transforms.Compose([
    # NEAREST interpolation keeps mask values restricted to the labels already
    # present; the default (bilinear) would blend neighbouring labels.
    transforms.Resize(RESIZE, interpolation=transforms.InterpolationMode.NEAREST),
    # NOTE(review): ToTensor rescales 8-bit values into [0, 1]; if integer
    # class ids are needed downstream, a different conversion is required — confirm.
    transforms.ToTensor()       # Convert mask to Tensor (values between 0 and 1)
])

# Load the Oxford-IIIT Pet dataset
dataset = datasets.OxfordIIITPet(
    root=SIMPLE_DATA_FILE_PATH,
    download=True,
    target_types='segmentation',  # Get segmentation masks
    transform=rgb_image_transform,    # Apply image transformations
    target_transform=mask_transform  # Apply mask transformations
)


In [2]:
class BlackAndWhiteCNN(nn.Module):
    """
    Placeholder for the encoder that feeds black-and-white images into the
    compression layer.

    No layers or ``forward`` are defined yet.

    Design note: CNNs tend to work better with batch normalization, so the
    plan is to use it for this black-and-white encoder.
    """
    def __init__(self):
        super().__init__()

class ColorCNN(nn.Module):
    """
    Placeholder for the encoder that feeds color images into the compression
    layer.

    No layers or ``forward`` are defined yet.

    Design notes:
    - Work in the LAB color space.
    - Make this network deeper than ``BlackAndWhiteCNN``.
    - CNNs tend to work better with batch normalization; TODO: first understand
      why ColorNet uses LayerNorm before choosing the normalization here.
    """
    def __init__(self):
        super().__init__()

class ColorNet(nn.Module):
    """
    Fully connected head that turns a feature vector into three probability
    distributions over ``num_bins`` quantized color bins.

    Original note: "Labnet is used to quantize the RGB". The three output
    groups presumably correspond to the L, a and b channels — TODO confirm.

    Design Question: Should the expert be trained on logits or softmax? If logits then we can produce a bigger weight
    which may make more sense. It can also be trained on softmax and then in the end to end it will not be connected
    with the softmax layer

    Args:
        in_features: Size of the last dimension of the input.
        hidden_dim: Width of the two hidden layers.
        num_bins: Number of bins per output group; the final layer emits
            ``3 * num_bins`` values per row.
    """
    def __init__(self, in_features=512, hidden_dim=256, num_bins=313):
        super(ColorNet, self).__init__()
        # Remembered so forward() reshapes with the configured bin count rather
        # than a hard-coded 313. A plain int attribute does not affect state_dict.
        self.num_bins = num_bins

        # Submodule names (including the ``bn*`` names for LayerNorm) are kept
        # as-is because they are the keys of saved state_dict checkpoints.
        self.fc1 = nn.Linear(in_features, hidden_dim)
        self.bn1 = nn.LayerNorm(hidden_dim)  # Use LayerNorm
        self.dropout1 = nn.Dropout(0.1)

        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.bn2 = nn.LayerNorm(hidden_dim)
        self.dropout2 = nn.Dropout(0.1)

        self.fc3 = nn.Linear(hidden_dim, 3 * num_bins)

    def forward(self, x):
        """Return softmax probabilities of shape ``(rows, 3, num_bins)``.

        Inputs with more than two dimensions are flattened to
        ``(-1, x.shape[-1])`` first, so ``rows`` is the product of all leading
        dimensions.
        """
        if x.dim() > 2:
            # reshape (not view) so non-contiguous inputs are accepted too.
            x = x.reshape(-1, x.shape[-1])  # Flatten for LayerNorm
        x = F.leaky_relu(self.bn1(self.fc1(x)), negative_slope=0.01)
        x = self.dropout1(x)
        x = F.leaky_relu(self.bn2(self.fc2(x)), negative_slope=0.01)
        x = self.dropout2(x)
        x = self.fc3(x)
        # Split the 3 * num_bins outputs into three groups of num_bins.
        x = x.view(-1, 3, self.num_bins)
        x = F.softmax(x, dim=-1)  # Apply Softmax to color bins
        return x
    
# Build a ColorNet with its default constructor arguments and write its
# freshly initialised parameters (no training has been run in this notebook)
# to disk.
model = ColorNet()
torch.save(obj=model.state_dict(), f="colornet_model.pth")
