In [1]:
import torch
import torch.nn as nn


class UNET(torch.nn.Module):
    def __init__(self,
               input_shape: int,
                 vocab_size: int) -> None:
        super(UNET, self).__init__()

        # Encoder


        self.encoder1 = nn.Sequential(
            nn.Conv2d(input_shape,64,kernel_size=3,stride=1,padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64,kernel_size=3,stride=1,padding=1),
            nn.ReLU(inplace=True)
        )
        self.pool1 = nn.MaxPool2d(2)

        self.encoder2 = nn.Sequential(
            nn.Conv2d(64,128,kernel_size=3,stride=1,padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128,128,kernel_size=3,stride=1,padding=1),
            nn.ReLU(inplace=True)
        )
        self.pool2 = nn.MaxPool2d(2)

        self.encoder3 = nn.Sequential(
            nn.Conv2d(128,256,kernel_size=3,stride=1,padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256,kernel_size=3,stride=1,padding=1),
            nn.ReLU(inplace=True)
        )
        self.pool3 = nn.MaxPool2d(2)

        self.encoder4 = nn.Sequential(
            nn.Conv2d(256,512,kernel_size=3,stride=1,padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512,512,kernel_size=3,stride=1,padding=1),
            nn.ReLU(inplace=True)
        )
        self.pool4 = nn.MaxPool2d(2)

        self.encoder5 = nn.Sequential(
            nn.Conv2d(512,1024,kernel_size=3,stride=1,padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(1024,1024,kernel_size=3,stride=1,padding=1),
            nn.ReLU(inplace=True)
        )



        # Decoder
        self.upconv1 = nn.ConvTranspose2d(1024, 512, kernel_size=2,stride=2)
        self.decoder1 = nn.Sequential(
            nn.Conv2d(1024, 512, kernel_size=3,stride=1,padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512,512,kernel_size=3,stride=1,padding=1),
            nn.ReLU(inplace=True)
        )

        self.upconv2 = nn.ConvTranspose2d(512, 256, kernel_size=2,stride=2)
        self.decoder2 = nn.Sequential(
            nn.Conv2d(512, 256, kernel_size=3,stride=1,padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256,kernel_size=3,stride=1,padding=1),
            nn.ReLU(inplace=True)
        )


        self.upconv3 = nn.ConvTranspose2d(256, 128, kernel_size=2,stride=2)
        self.decoder3 = nn.Sequential(
            nn.Conv2d(256, 128,kernel_size=3,stride=1,padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3,stride=1,padding=1),
            nn.ReLU(inplace=True)
        )

        self.upconv4 = nn.ConvTranspose2d(128,64, kernel_size=2,stride=2)
        self.decoder4 = nn.Sequential(
            nn.Conv2d(128,64, kernel_size=3,stride=1,padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64,64, kernel_size=3,stride=1,padding=1),
            nn.ReLU(inplace=True)
        )

        #self.classifier = nn.Sequential(nn.Conv2d(64, output_shape, 3, 1, padding=1),nn.ReLU())
        self.adaptive_pool = nn.AdaptiveAvgPool2d((20, 1))
        self.lstm = nn.LSTM(64, 512, num_layers=2, batch_first=True)
        self.fc = nn.Linear(512, vocab_size)
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        # Encoder
        x1 = self.encoder1(x)
        x = self.pool1(x1)
        #print("x1",x.shape)
        x2 = self.encoder2(x)
        x = self.pool2(x2)
        #print("x2",x.shape)
        x3 = self.encoder3(x)

        x = self.pool3(x3)
        #print("x3",x.shape)
        x4 = self.encoder4(x)

        x = self.pool4(x4)

        x = self.encoder5(x)


        # Decoder
        x = self.upconv1(x)

        x = torch.cat([x, x4], dim=1) # Skip connection
        x = self.decoder1(x)
        #print("x1",x.shape)
        x = self.upconv2(x)
        x = torch.cat([x, x3], dim=1) # Skip connection
        x = self.decoder2(x)
        #print("x2",x.shape)
        x = self.upconv3(x)
        x = torch.cat([x, x2], dim=1) # Skip connection
        x = self.decoder3(x)
        #print("x3",x.shape)
        x = self.upconv4(x)
        x = torch.cat([x, x1], dim=1) # Skip connection
        x = self.decoder4(x)
        #print("x4",x.shape)

        x = self.adaptive_pool(x)
        x = x.squeeze(-1)
        x = x.permute(0, 2, 1)
        #batch_size = x.size(0)
        #x = x.view(batch_size, 1024, -1).permute(0, 2, 1)  # [batch, seq, features]
        lstm_out, _ = self.lstm(x)
        outputs = self.fc(self.dropout(lstm_out))

        #x = self.classifier(x)
        #print(x.shape)
        #x = self.dropout(x)
        return outputs

In [2]:
from google.colab import drive
drive.mount('/content/drive')
folder_path = '/content/drive/MyDrive/5527Pdata/'
import os
print(os.listdir(folder_path))

Mounted at /content/drive
['synthetic_images', 'synthetic_labels.txt', 'test_images', 'test_labels.txt', 'train_images', 'train_labels.txt', 'valid_images', 'valid_labels.txt', 'symbols_images', 'symbols_labels.txt']


In [3]:
import os
import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from sklearn.model_selection import train_test_split


class ImageTextDataset(Dataset):
    def __init__(self, label_file_path, image_dir, transform=None, vocab=None, max_length=20):
        self.image_dir = image_dir
        self.transform = transform or transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor()
        ])
        self.vocab = vocab
        self.max_length = max_length

        # Read labels and image names
        self.image_names = []
        self.raw_labels = []  # Store raw text labels
        with open(label_file_path, 'r') as f:
            for line in f:
                if line.strip():
                    parts = line.strip().split('\t')
                    if len(parts) == 2:
                        self.image_names.append(parts[0])
                        self.raw_labels.append(parts[1])

        # Only process to tokens if vocab exists
        self.labels = []
        if self.vocab is not None:
            for label in self.raw_labels:
                tokens = label.split('\\')[:self.max_length]
                encoded = [self.vocab.get(token, 1) for token in tokens]  # 1 = <unk>
                padded = encoded + [0]*(self.max_length - len(encoded))
                self.labels.append(torch.tensor(padded))

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, idx):
        # Load image
        img_path = os.path.join(self.image_dir, self.image_names[idx])
        image = Image.open(img_path).convert('RGB')
        image = self.transform(image)

        # Return encoded label if vocab exists, else raw text
        if self.vocab is not None:
            return image, self.labels[idx]
        else:
            return image, self.raw_labels[idx]



def create_train_test_split(dataset, test_size=0.2, random_state=42):
    """
    Split the dataset into training and testing sets

    Args:
        dataset (ImageTextDataset): The dataset to split
        test_size (float): Proportion of the dataset to include in the test split
        random_state (int): Random seed for reproducibility

    Returns:
        tuple: (train_dataset, test_dataset)
    """
    # Method 1: Using PyTorch's random_split
    train_size = int((1 - test_size) * len(dataset))
    test_size = len(dataset) - train_size

    train_dataset, test_dataset = random_split(
        dataset,
        [train_size, test_size],
        generator=torch.Generator().manual_seed(random_state)
    )

    return train_dataset, test_dataset

# Example usage
if __name__ == "__main__":


    # Example paths
    label_file_path = folder_path + "synthetic_labels.txt"
    image_dir = folder_path + "synthetic_images"

    # Define transformations
    transform = transforms.Compose([
        transforms.Resize((224, 224)),  # Resize images for consistency
        transforms.ToTensor(),           # Convert PIL images to tensors
    ])

    # Create dataset with transform
    dataset = ImageTextDataset(label_file_path, image_dir, transform=transform)

    # Split dataset
    train_dataset, test_dataset = create_train_test_split(dataset, test_size=0.2)

    # Create dataloaders
    batch_size = 4
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)

    # Print dataset statistics
    print(f"Full dataset size: {len(dataset)}")
    print(f"Training set size: {len(train_dataset)}")
    print(f"Test set size: {len(test_dataset)}")

    # Try accessing a batch from the training dataloader
    print("\nExample training batch:")
    for images, labels in train_loader:
        print(f"Image batch shape: {images.shape}")
        print(f"Labels: {labels}")
        break  # Only print first batch

    # Try accessing a batch from the test dataloader
    print("\nExample test batch:")
    for images, labels in test_loader:
        print(f"Image batch shape: {images.shape}")
        print(f"Labels: {labels}")
        break  # Only print first batch


Full dataset size: 100
Training set size: 80
Test set size: 20

Example training batch:
Image batch shape: torch.Size([4, 3, 224, 224])
Labels: ('d_{FG}=d_F+d_G.', '\\varphi(x \\alpha, y \\beta) = \\sigma(\\alpha) \\, \\varphi(x, y) \\, \\beta .', 'B = \\operatorname{core}B.', '\\hat{\\mathbf x} = R_1^{-1} \\left(Q_1^\\textsf{T} \\mathbf{b}\\right)')

Example test batch:
Image batch shape: torch.Size([4, 3, 224, 224])
Labels: ('Az+B', '\\int (\\sin(x)+1)dx=\\int \\sin(x)dx + \\int 1dx=-\\cos(x)+x+C', 'm^*\\left(E, \\hat{B}, k_{\\hat{B}}\\right) = \\frac{\\hbar^2}{2\\pi} \\cdot \\frac{\\partial}{\\partial E} A\\left(E, \\hat{B}, k_{\\hat{B}}\\right)', '\\tan(-\\theta) = -\\tan \\theta')


In [20]:
import torch
import torch.nn as nn
import os
import numpy as np
from PIL import Image
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from collections import Counter
def build_vocabulary(dataset):
    # Define special tokens first
    special_tokens = ['<pad>', '<unk>', '<sos>', '<eos>']
    vocabulary = {token: idx for idx, token in enumerate(special_tokens)}

    # Collect all tokens from dataset
    all_tokens = []
    for _, label in dataset:
        all_tokens.extend(label.split('\\'))

    # Create word counts while filtering out special tokens
    word_counts = Counter(
        token for token in all_tokens
        if token not in special_tokens
    )

    # Add regular tokens after special tokens
    for idx, (token, _) in enumerate(word_counts.most_common()):
        vocabulary[token] = idx + len(special_tokens)

    return vocabulary

temp_dataset = ImageTextDataset(label_file_path, image_dir)
vocab = build_vocabulary(temp_dataset)
vocab_size = len(vocab)

In [5]:
import torch
from collections import Counter


for images, labels in test_loader:
  # Sample text
  text = labels[0]

  # 1. Tokenization (split by space)
  tokens = text.split('\\')

  # 2. Vocabulary Creation
  word_counts = Counter(tokens)
  vocabulary = {token: index for index, token in enumerate(word_counts)}

  # 3. Numerical Encoding
  encoded_text = [vocabulary[token] for token in tokens]

  # 4. Tensor Conversion
  text_tensor = torch.tensor(encoded_text)

  print("Original Text:", text)
  print("Tokens:", tokens)
  print("Vocabulary:", vocabulary)
  print("Vocabulary size :", len(vocabulary))
  print("Encoded Text:", encoded_text)
  print("Tensor:", text_tensor)
  break

Original Text: Az+B
Tokens: ['Az+B']
Vocabulary: {'Az+B': 0}
Vocabulary size : 1
Encoded Text: [0]
Tensor: tensor([0])


In [28]:
# Full dataset with transforms
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
dataset = ImageTextDataset(label_file_path, image_dir, transform=transform, vocab=vocab)

# Split dataset
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

# Data loaders
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1)

In [31]:
import torch
import torch.nn as nn
from torchvision import models
from torchvision import models
class CNNTransformer(nn.Module):
    def __init__(self, vocab_size, max_seq_length=20):
        super().__init__()
        # Image encoder
        self.cnn = models.resnet50(pretrained=True)
        self.cnn.fc = nn.Sequential(
            nn.Linear(self.cnn.fc.in_features, 512),
            nn.ReLU(),
            nn.Dropout(0.3)
        )

        # Text decoder
        self.embedding = nn.Embedding(vocab_size, 512)
        self.transformer_decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model=512, nhead=8, batch_first=True),  # Add batch_first
            num_layers=3
        )
        self.fc = nn.Linear(512, vocab_size)
        self.max_seq_length = max_seq_length

        # Positional encoding
        self.positional_encoding = nn.Parameter(
            torch.zeros(1, max_seq_length, 512)
        )
        nn.init.trunc_normal_(self.positional_encoding)

    def forward(self, images, captions=None):
        # Encode images
        img_features = self.cnn(images)  # [batch, 512]
        img_features = img_features.unsqueeze(1)  # [batch, 1, 512]

        if captions is None:
            return self.generate(img_features)

        # Embed captions
        seq_length = captions.size(1)
        embeddings = self.embedding(captions)  # [batch, seq_len, 512]
        embeddings += self.positional_encoding[:, :seq_length, :]

        # Transformer decoding
        decoder_output = self.transformer_decoder(
            tgt=embeddings,
            memory=img_features
        )

        return self.fc(decoder_output)

    def generate(self, img_features, temperature=1.0):
        batch_size = img_features.size(0)
        device = img_features.device

        # Start with <sos> token
        outputs = torch.full((batch_size, 1), 2, dtype=torch.long, device=device)

        for _ in range(self.max_seq_length-1):
            embeddings = self.embedding(outputs) + self.positional_encoding[:, :outputs.size(1), :]
            print(img_features.shape)

            decoder_out = self.transformer_decoder(
                tgt=embeddings,
                memory=img_features
            )

            logits = self.fc(decoder_out[:, -1, :]) / temperature
            next_token = torch.argmax(logits, dim=-1, keepdim=True)
            outputs = torch.cat([outputs, next_token], dim=1)

        return outputs

In [47]:
import torch
import torch.nn as nn
from torchvision import models
from torchvision import models
class CNNTransformer(nn.Module):
    def __init__(self, vocab_size, max_seq_length=20):
        super().__init__()
        # Image encoder
        self.cnn = models.resnet50(pretrained=True)
        self.cnn.fc = nn.Sequential(
            nn.Linear(self.cnn.fc.in_features, 512),
            nn.ReLU(),
            nn.Dropout(0.3)
        )

        # Text decoder
        self.embedding = nn.Embedding(vocab_size, 512)
        self.transformer_decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model=512, nhead=8, batch_first=True),  # Add batch_first
            num_layers=3
        )
        self.fc = nn.Linear(512, vocab_size)
        self.max_seq_length = max_seq_length

        # Positional encoding
        self.positional_encoding = nn.Parameter(
            torch.zeros(1, max_seq_length, 512)
        )
        nn.init.trunc_normal_(self.positional_encoding)

    def forward(self, images, captions=None):
        # Encode images
        img_features = self.cnn(images)  # [batch, 512]
        img_features = img_features.unsqueeze(1)  # [batch, 1, 512]

        if captions is None:
            return self.generate(img_features)

        # Embed captions
        seq_length = captions.size(1)
        embeddings = self.embedding(captions)  # [batch, seq_len, 512]
        embeddings += self.positional_encoding[:, :seq_length, :]

        # Transformer decoding
        decoder_output = self.transformer_decoder(
            tgt=embeddings,
            memory=img_features
        )

        return self.fc(decoder_output)

    def generate(self, img_features, temperature=1.0):
        with torch.no_grad():
          img_features = self.cnn(images).unsqueeze(1)
        batch_size = img_features.size(0)
        device = img_features.device

        # Start with <sos> token
        outputs = torch.full((batch_size, 1), 2, dtype=torch.long, device=device)

        for _ in range(self.max_seq_length-1):
            embeddings = self.embedding(outputs) + self.positional_encoding[:, :outputs.size(1), :]
            print(img_features.shape)

            decoder_out = self.transformer_decoder(
                tgt=embeddings,
                memory=img_features
            )

            logits = self.fc(decoder_out[:, -1, :]) / temperature
            next_token = torch.argmax(logits, dim=-1, keepdim=True)
            outputs = torch.cat([outputs, next_token], dim=1)

        return outputs
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = CNNTransformer(vocab_size=len(vocab)).to(device)

# Optimizer and loss
# Replace criterion with manual cross entropy
class StableCrossEntropy(nn.Module):
    def __init__(self, ignore_index=-100):
        super().__init__()
        self.ignore_index = ignore_index

    def forward(self, logits, targets):
        # Add numerical stability
        logits = logits - logits.max(dim=-1, keepdim=True)[0]
        log_probs = torch.log_softmax(logits, dim=-1)

        # Filter ignored indices
        mask = targets != self.ignore_index
        log_probs = log_probs[mask]
        targets = targets[mask]

        return -log_probs.gather(1, targets.unsqueeze(1)).mean()

# Replace criterion
criterion = StableCrossEntropy(ignore_index=0)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-5)
#criterion = nn.CrossEntropyLoss(ignore_index=0, label_smoothing=0.1)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3)

best_loss = float('inf')
for epoch in range(100):
    # Training
    model.train()
    train_loss = 0
    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)

        # Shift labels for teacher forcing
        decoder_input = labels[:, :-1]
        decoder_target = labels[:, 1:]

        optimizer.zero_grad()

        # Forward pass
        outputs = model(images, decoder_input)

        # Reshape for loss calculation
        outputs = outputs.view(-1, outputs.size(-1))
        targets = decoder_target.contiguous().view(-1)

        loss = criterion(outputs, targets)
        loss.backward()

        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()
        if torch.isnan(loss):
          print("NaN in model outputs!")
          continue
        train_loss += loss.item()

    # Validation
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            decoder_input = labels[:, :-1]
            decoder_target = labels[:, 1:]
            # Generate predictions
            outputs = model(images, decoder_input)

            # Calculate loss
            outputs = outputs.view(-1, outputs.size(-1))
            targets = decoder_target.contiguous().view(-1)
            loss = criterion(outputs, targets)
            if torch.isnan(loss):
              print("NaN in model outputs!")
              continue
            test_loss += loss.item()

    # Update scheduler
    avg_train_loss = train_loss / len(train_loader)
    avg_test_loss = test_loss / len(test_loader)
    scheduler.step(avg_test_loss)

    print(f"Epoch {epoch+1}/100")
    print(f"Train Loss: {avg_train_loss:.4f} | Test Loss: {avg_test_loss:.4f}")

    # Save best model
    if avg_test_loss < best_loss:
        best_loss = avg_test_loss
        #torch.save(model.state_dict(), 'best_transformer_model.pth')

NaN in model outputs!
NaN in model outputs!
NaN in model outputs!
NaN in model outputs!
NaN in model outputs!
NaN in model outputs!
NaN in model outputs!
NaN in model outputs!
NaN in model outputs!
NaN in model outputs!
NaN in model outputs!
NaN in model outputs!
NaN in model outputs!
NaN in model outputs!
NaN in model outputs!
NaN in model outputs!
NaN in model outputs!
NaN in model outputs!
Epoch 1/100
Train Loss: 5.6560 | Test Loss: 5.5240
NaN in model outputs!
NaN in model outputs!
NaN in model outputs!
NaN in model outputs!
NaN in model outputs!
NaN in model outputs!
NaN in model outputs!
NaN in model outputs!
NaN in model outputs!
NaN in model outputs!


KeyboardInterrupt: 