In [1]:
import torch
print(torch.cuda.is_available())
print(torch.cuda.device_count())


True
1


In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torchvision.models import ResNet50_Weights
import math
from models.transformer import   

# 1. Positional Encoding for 2D spatial information
class PositionalEncoding2D(nn.Module):

    def __init__(self, d_model, height, width):
        super(PositionalEncoding2D, self).__init__()
        self.d_model = d_model

        # Create a positional encoding matrix
        pe = torch.zeros(height, width, d_model)
        y_pos = torch.arange(height, dtype=torch.float).unsqueeze(1)  # Shape: (height, 1)
        x_pos = torch.arange(width, dtype=torch.float).unsqueeze(0)   # Shape: (1, width)

        # Calculate the positional encodings
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))  # Shape: (d_model/2)

        # Apply sine and cosine to even and odd indices
        pe[:, :, 0::2] = torch.sin(y_pos * div_term.unsqueeze(0))  # Shape: (height, width, d_model/2)
        pe[:, :, 1::2] = torch.cos(x_pos * div_term)              # Shape: (height, width, d_model/2)

        # Reshape to (height * width, d_model)
        pe = pe.view(height * width, d_model)
        self.register_buffer('pe', pe)


    def forward(self, x):
        # Add positional encoding to the input tensor
        return x + self.pe[:x.size(1), :].unsqueeze(0) 

# 2. ResNet-50 Backbone to extract feature maps
class ResNetFeatureExtractor(nn.Module):
    def __init__(self, d_model):
        super(ResNetFeatureExtractor, self).__init__()
        resnet = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
        self.backbone = nn.Sequential(*list(resnet.children())[:-2])  # Remove FC layer and AvgPool
        self.conv = nn.Conv2d(2048, d_model, kernel_size=1)  # Reduce channels to d_model

    def forward(self, x):
        x = self.backbone(x)  # Shape: [B, 2048, H, W]
        x = self.conv(x)      # Shape: [B, d_model, H, W]
        return x

# 3. Transformer Decoder with CTC Head
class TransformerWithCTC(nn.Module):
    def __init__(self, d_model, nhead, num_encoder_layers, num_decoder_layers, num_classes):
        super(TransformerWithCTC, self).__init__()
        self.d_model = d_model

        # Image-to-Character Module
        self.cnn = ResNetFeatureExtractor(d_model)
        
        self.pos_encoder = PositionalEncoding2D(d_model, height=256, width=256)  # Assuming 16x16 feature maps
        
        # self.transformer_encoder = nn.TransformerEncoder(
        #     nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead),
        #     num_layers=num_encoder_layers
        # )

        # # Transformer Decoder for C2W
        # self.char_embeddings = nn.Embedding(num_classes, d_model)
        self.transformer_decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead),
            num_layers=num_decoder_layers
        )

        # Transformer for I2C
        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            batch_first=True
        )

        # Character Embeddings as learnable parameters  
        self.char_embeddings = nn.Embedding(num_classes, d_model)

        # Output heads
        self.fc_class = nn.Linear(d_model, num_classes)  # Predict character class
        self.fc_position = nn.Linear(d_model, 2)        # Predict character position (x, y) 

        # CTC Decoder Head
        # self.fc_ctc = nn.Linear(d_model, num_classes)

    def forward(self, x, tgt_seq):
        # CNN Feature extraction
        cnn_out = self.cnn(x)                      # Shape: [B, d_model, H, W]
        cnn_out = cnn_out.flatten(2).permute(2, 0, 1)  # Reshape to [HW, B, d_model]
        print("Feature extracted. Size: ", cnn_out.shape)
        
        cnn_out = self.pos_encoder(cnn_out)        # Add 2D positional encoding
        print("Positional encoding. Size: ", cnn_out.shape)

        # Character Embeddings for the target sequence
        # tgt_embeddings = self.char_embeddings(tgt_seq)  # Shape: [T, B, d_model]
        # print("Character embeddings. Size: ", tgt_embeddings.shape)
        
        # # Encoder
        # memory = self.transformer_encoder(cnn_out)  # Shape: [HW, B, d_model
        # print("Encoder. Size: ", memory.shape)

        # # Decoder
        # tgt_embeddings = self.char_embeddings(tgt_seq)  # Shape: [T, B, d_model]
        # print("Character embeddings. Size: ", tgt_embeddings.shape)


        # return ctc_out
        # Transformer Encoder-Decoder
        transformer_out = self.transformer(cnn_out, tgt_seq)
        print("Transformer out. Size: ", transformer_out.shape)

        # Output: Class and Position predictions
        class_out = self.fc_class(transformer_out)      # Shape: [T, B, num_classes]
        position_out = self.fc_position(transformer_out)  # Shape: [T, B, 2]

        decoded_output = self.transformer_decoder(tgt_embeddings, memory)
        print("Decoder. Size: ", decoded_output.shape)
        # # CTC Head
        ctc_out = self.fc_ctc(decoded_output)  # Shape: [T, B, num_classes]
        # return class_out, position_out
        return ctc_out, class_out, position_out

# 4. CTC Decoder
def ctc_decoder(ctc_output, blank_label=0):
    """
    Simulates the CTC decoding process.
    Args:
        ctc_output: Tensor of shape [T, B, num_classes] (logits).
        blank_label: The index of the CTC blank label.
    Returns:
        Decoded output as a list of sequences.
    """
    decoded_sequences = []
    for batch in ctc_output.permute(1, 0, 2):  # Iterate over batch dimension
        seq = []
        prev_token = blank_label
        for timestep in batch:
            token = timestep.argmax().item()
            if token != blank_label and token != prev_token:
                seq.append(token)
            prev_token = token
        decoded_sequences.append(seq)
    return decoded_sequences


def encode_labels(labels, max_length=None):
    """
    Convert string labels to numerical tensor format.
    
    Args:
        labels (list of str): List of string labels to encode.
        max_length (int, optional): Maximum length for padding. If None, it will be determined from the labels.

    Returns:
        torch.Tensor: A tensor of shape (batch_size, max_length) containing the encoded labels.
    """
    # Create a mapping from characters to indices
    char_to_index = {char: idx + 1 for idx, char in enumerate(sorted(set('abcdefghijklmnopqrstuvwxyz0123456789 ')))}
    char_to_index['<blank>'] = 0  # Add a blank token for CTC loss

    # Encode the labels
    encoded_labels = []
    for label in labels:
        encoded_label = [char_to_index[char] for char in label if char in char_to_index]
        encoded_labels.append(encoded_label)

    # Determine the maximum length if not provided
    if max_length is None:
        max_length = max(len(seq) for seq in encoded_labels)

    # Pad sequences to the maximum length
    padded_labels = [seq + [0] * (max_length - len(seq)) for seq in encoded_labels]  # 0 for padding

    return torch.tensor(padded_labels, dtype=torch.long)


In [2]:
from data.dataloader import SCUTLoader
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as transforms


In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split

# Assuming TransformerWithCTC and SCUTLoader are defined elsewhere

if __name__ == "__main__":
    # Hyperparameters
    BATCH_SIZE = 32
    IMAGE_SIZE = (3, 224, 224)
    D_MODEL = 512
    NHEAD = 8
    NUM_ENCODER_LAYERS = 3
    NUM_DECODER_LAYERS = 1
    NUM_CLASSES = 37  # 26 letters + 1 blank token + 0-9
    TGT_SEQ_LEN = 10   # Target sequence length

    DATA_IMG_PATH = "./data/SCUT-CTW1500/cropped_train_images2/"
    DATA_LABELS_PATH = "./data/SCUT-CTW1500/processed_labels2.csv"

    # Model
    model = TransformerWithCTC(
        d_model=D_MODEL,
        nhead=NHEAD,
        num_encoder_layers=NUM_ENCODER_LAYERS,
        num_decoder_layers=NUM_DECODER_LAYERS,
        num_classes=NUM_CLASSES
    )
    transform = transforms.Compose([
        transforms.Resize((128, 128)),  # Resize images to 256x256
        transforms.ToTensor(),  # Convert image to PyTorch tensor
        # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Normalize based on ImageNet stats
    ])

    # Load dataset
    dataset = SCUTLoader(image_dir=DATA_IMG_PATH, label_dir=DATA_LABELS_PATH, transform=transform)
    print("Dataset size: ", len(dataset))
    
    train_size = int(0.8 * len(dataset))
    test_size = len(dataset) - train_size
    train_dataset, test_dataset = random_split(dataset, [train_size, test_size])
    
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

    # Move model to device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Define loss function and optimizer
    criterion = nn.CTCLoss(blank=37)  # Assuming 37 is the index for the blank token
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # Training loop
    num_epochs = 10
    for epoch in range(num_epochs):
        model.train()
        for batch_idx, (input_images, tgt_seq) in enumerate(train_loader):
            # input_images, tgt_seq = input_images.to(device), tgt_seq.to(device)
            input_images = input_images.to(device)
            tgt_seq = encode_labels(tgt_seq)
             
            tgt_seq = tgt_seq.to(device)
            optimizer.zero_grad()

            # Forward pass
            print("Input images shape:", input_images.shape)
            print("Target sequence shape:", tgt_seq.shape)

            print("Batch index: ", batch_idx)
            print("Input images: ", input_images[0])
            print("Target sequence: ", tgt_seq[0])

            predicted_word, class_out, position_outs = model(input_images, tgt_seq)  # Assuming model takes only input_images
            print(class_out)
            print(position_out)
            decoded_sequences = ctc_decoder(predicted_word, blank_label=0)
            print(decoded_sequences)

            # Calculate CTC loss
            # Note: You may need to adjust the shape of ctc_out and tgt_seq
            # Ensure tgt_seq is the correct shape for CTC loss
            loss = criterion(ctc_out, tgt_seq)
            loss.backward()
            optimizer.step()

            if batch_idx % 10 == 0:
                print(f'Epoch [{epoch+1}/{num_epochs}], Step [{batch_idx}/{len(train_loader)}], Loss: {loss.item():.4f}')

    # After training, you can evaluate the model on the test set
    # model.eval()
    # with torch.no_grad():
        # for input_images, tgt_seq in test_loader:
            # input_images = input_images.to(device)
            # ctc_out = model(input_images)
            # Decode the output
            # decoded_sequences = ctc_decoder(ctc_out, blank_label=0)
            # Process decoded_sequences as needed

   Unnamed: 0 img_file_name           label
0           0     00000.jpg   MEXICO, D. F.
1           0     00001.jpg  Tel.5-45-25-05
2           0     00002.jpg  Newton No. 136
3           0     00003.jpg          CEDROS
4           0     00004.jpg      PASTELERIA
Dataset size:  7703
Input images shape: torch.Size([32, 3, 128, 128])
Target sequence shape: torch.Size([32, 31])
Batch index:  0
Input images:  tensor([[[1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         ...,
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.]],

        [[1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         ...,
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.]],

        [[1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ...

RuntimeError: the batch number of src and tgt must be equal

In [3]:
import torch
# torch.cuda.is_available()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda
