In [5]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
from PIL import Image
import pickle
import math
import torch.nn as nn
from DataLoader import get_dataloader
import matplotlib.pyplot as plt
%matplotlib inline

class CustomDataset(Dataset):
    def __init__(self, image_dir, label_file, transform=None, cache_file='valid_indices_cache.pkl'):
        self.image_dir = image_dir
        self.transform = transform
        self.cache_file = cache_file
        with open(label_file, 'r', encoding='utf-8') as f:
            self.labels = f.readlines()
        self.valid_indices = self._load_or_create_valid_indices()

    def _load_or_create_valid_indices(self):
        if os.path.exists(self.cache_file):
            with open(self.cache_file, 'rb') as f:
                valid_indices = pickle.load(f)
        else:
            valid_indices = self._get_valid_indices()
            with open(self.cache_file, 'wb') as f:
                pickle.dump(valid_indices, f)
        return valid_indices

    def _get_valid_indices(self):
        valid_indices = []
        for idx in range(len(self.labels)):
            img_name = os.path.join(self.image_dir, f"{idx:07d}.png")
            if os.path.exists(img_name):
                valid_indices.append(idx)
        return valid_indices

    def __len__(self):
        return len(self.valid_indices)

    def __getitem__(self, idx):
        actual_idx = self.valid_indices[idx]
        img_name = os.path.join(self.image_dir, f"{actual_idx:07d}.png")
        image = Image.open(img_name).convert('RGB')
        label = self.labels[actual_idx].strip()
        image = self.resize_and_pad(image,800,400)
        if self.transform:
            image = self.transform(image)
        else:
            transform = transforms.Compose([
                transforms.ToTensor()
            ])
            image = transform(image)
        
        return image, label

    def resize_and_pad(self, image, target_width, target_height):
        # Calculate the ratio to maintain the aspect ratio
        original_width, original_height = image.size
        ratio = min(target_width / original_width, target_height / original_height)
    
        # Resize the image while maintaining the aspect ratio
        new_size = (int(original_width * ratio), int(original_height * ratio))
        resized_image = image.resize(new_size, Image.LANCZOS)
    
        # Create a new image with the specified target size and a white background
        new_image = Image.new("RGB", (target_width, target_height), (255, 255, 255))
    
        # Calculate the position to paste the resized image on the white background
        paste_x = (target_width - new_size[0]) // 2
        paste_y = (target_height - new_size[1]) // 2
    
        # Paste the resized image onto the white background
        new_image.paste(resized_image, (paste_x, paste_y))
    
        return new_image

def get_dataloader(batch_size, image_dir='../../UniMER-1M/images/', label_file='../../UniMER-1M/train.txt', transform=None, cache_file='valid_indices_cache.pkl'):
    dataset = CustomDataset(image_dir=image_dir, label_file=label_file, transform=transform, cache_file=cache_file)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    return dataloader



In [6]:
class PositionalEncoding2D(nn.Module):
    def __init__(self, d_model, height, width):
        super().__init__()
        self.height = height
        self.width = width
        self.d_model = d_model
        self.pe: torch.Tensor = self._get_positional_encoding(d_model, height, width)

    def _get_positional_encoding(self, d_model, height, width):
        """
        :param d_model: dimension of the model
        :param height: height of the positions
        :param width: width of the positions
        :return: d_model*height*width position matrix
        """
        if d_model % 4 != 0:
            raise ValueError("Cannot use sin/cos positional encoding with "
                            "odd dimension (got dim={:d})".format(d_model))
        pe = torch.zeros(d_model, height, width)
        # Each dimension use half of d_model
        d_model = int(d_model / 2)
        div_term = torch.exp(torch.arange(0., d_model, 2) *
                            -(math.log(10000.0) / d_model))
        pos_w = torch.arange(0., width).unsqueeze(1)
        pos_h = torch.arange(0., height).unsqueeze(1)
        pe[0:d_model:2, :, :] = torch.sin(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
        pe[1:d_model:2, :, :] = torch.cos(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
        pe[d_model::2, :, :] = torch.sin(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
        pe[d_model + 1::2, :, :] = torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)

        return pe
    
    def forward(self, x):
        """
        Args:
            x: Tensor of shape (batch_size, channels, height, width)
        Returns:
            Tensor with positional encodings added, of shape (batch_size, channels, height, width)
        """
        batch_size, channels, height, width = x.size()
        # Ensure the input has the correct number of channels
        assert self.d_model == channels, "Dimension mismatch: d_model and input channels must be the same"
        # Add positional encodings to the input tensor
        x = x + self.pe.unsqueeze(0) #the unsqueeze() might not be necessary, idk
        
        return x


In [None]:
train_loader = get_dataloader(batch_size=10)

# Assuming the DataLoader for training is defined as train_loader
i = 0
for images, labels in train_loader:
    pil_image = transforms.ToPILImage()(images[0]).convert("RGB")
    # pil_image.show()
    # print(labels[0])    
    plt.imshow(images[0].permute(1, 2, 0))
    
    plt.title(f"Label: {labels[0]}")
    plt.show()
    if i > 10:
        break
    i += 1

In [14]:
import torch
model = models.densenet169(pretrained=True)
# Remove the final fully connected layer to get the final feature maps
model = nn.Sequential(*list(model.children())[:-1])

#to get the size of the output of the denseNet
image_tensor = torch.randn(1, 3, 800, 400)
output = model(image_tensor)
size = output.size()

model.add_module('PositionalEncoding2D', PositionalEncoding2D(1664, size[2], size[3])) # hardcoded this based on denseNet output size
#yet to automate that, the 7x7 is dependent on the input size of the model, it can be anything. Still need to automate this
image_tensor = torch.randn(1, 3, 800, 400)
print(model(image_tensor))





1664 25
tensor([[[[-4.2097e-03,  8.4160e-01,  9.0742e-01,  ...,  4.1541e-01,
           -5.4218e-01, -9.9898e-01],
          [-6.5510e-04,  8.4362e-01,  9.1039e-01,  ...,  4.1167e-01,
           -5.4177e-01, -9.9840e-01],
          [-7.3636e-04,  8.4410e-01,  9.1004e-01,  ...,  4.1122e-01,
           -5.4156e-01, -9.9990e-01],
          ...,
          [ 1.1227e-03,  8.4144e-01,  9.0773e-01,  ...,  4.1352e-01,
           -5.4382e-01, -9.9985e-01],
          [-1.8630e-03,  8.3972e-01,  9.0952e-01,  ...,  4.0974e-01,
           -5.4319e-01, -1.0022e+00],
          [-5.6795e-03,  8.4561e-01,  9.0938e-01,  ...,  4.1248e-01,
           -5.4239e-01, -1.0001e+00]],

         [[ 1.0001e+00,  5.4033e-01, -4.1668e-01,  ..., -9.1465e-01,
           -8.4142e-01,  5.4097e-03],
          [ 1.0010e+00,  5.3968e-01, -4.1372e-01,  ..., -9.1030e-01,
           -8.3687e-01,  2.5407e-03],
          [ 9.9948e-01,  5.4101e-01, -4.1600e-01,  ..., -9.0913e-01,
           -8.3588e-01,  4.4913e-03],
          ..