In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# --- Model and training hyper-parameters -----------------------------------
# Target image dimensions.
IMAGE_SIZE = (299, 299)
# Upper bound on the vocabulary size.
MAX_VOCAB_SIZE = 2_000_000
# Every sequence is fixed to this length.
SEQ_LENGTH = 25
# Shared width of the image embeddings and the token embeddings.
EMBED_DIM = 128
# How many self-attention heads to use.
NUM_HEADS = 8
# Units per layer in the feed-forward network (alternative value noted: 1024).
FF_DIM = 128
# Shuffle dimension for the dataset (the original note refers to tf.data.Dataset).
SHUFFLE_DIM = 512
# Samples per batch.
BATCH_SIZE = 64
# How many epochs to train for.
EPOCHS = 14

# --- Dataset reduction -----------------------------------------------------
# To work with fewer train/valid images, set REDUCE_DATASET = True and pick
# the image counts below.
#
# COCO dataset limits:
#   train images available : 68363
#   valid images available : 33432
REDUCE_DATASET = False
# Train image count; must lie in [1, 68363].
NUM_TRAIN_IMG = 68363
# Valid image count; must lie in [1, 33432].
# IMPORTANT: the test set size is 33432 - NUM_VALID_IMG.
# Example: NUM_VALID_IMG = 20000 gives 20000 validation images and uses the
# last 13432 images as the test set.
NUM_VALID_IMG = 20000
# Apply data augmentation to the training set.
TRAIN_SET_AUG = True
# Apply data augmentation to the validation set.
VALID_SET_AUG = False
# Enable to also compute performance on the test set.
TEST_SET = False

# --- Input files -----------------------------------------------------------
# Path of the train_data.json file.
train_data_json_path = "COCO_dataset/captions_mapping_train.json"
# Path of the valid_data.json file.
valid_data_json_path = "COCO_dataset/captions_mapping_valid.json"
# Path of the text_data.json file.
text_data_json_path  = "COCO_dataset/text_data.json"

# --- Output ----------------------------------------------------------------
# Directory where training artefacts are saved.
SAVE_DIR = "save_train_dir/"

In [25]:
class TransformerEncoderBlock(nn.Module):
    """Encoder block: linear projection -> self-attention -> residual + LayerNorm.

    Tensors are batch-first: ``forward`` takes and returns tensors of shape
    ``(batch_size, seq_length, embed_dim)``.

    Args:
        embed_dim: Feature width of the inputs and of the attention layer.
        dense_dim: Output width of the input projection. It must equal
            ``embed_dim`` because the projected tensor is fed straight into
            an attention layer (and a LayerNorm) built for ``embed_dim``.
        num_heads: Number of attention heads; must divide ``embed_dim``.

    Raises:
        ValueError: If ``dense_dim`` differs from ``embed_dim``.
    """

    def __init__(self, embed_dim, dense_dim, num_heads):
        super().__init__()
        if dense_dim != embed_dim:
            # Fail at construction time with a clear message instead of deep
            # inside the attention layer during the first forward pass.
            raise ValueError(
                f"dense_dim ({dense_dim}) must equal embed_dim ({embed_dim}): "
                "the projected inputs are passed directly to the attention layer"
            )
        self.embed_dim = embed_dim
        self.dense_dim = dense_dim
        self.num_heads = num_heads
        # batch_first=True: without it nn.MultiheadAttention interprets dim 0
        # as the sequence axis and would attend across the samples of a batch.
        self.attention = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.dense_proj = nn.Linear(embed_dim, dense_dim)
        self.layernorm_1 = nn.LayerNorm(embed_dim)

    def forward(self, inputs, mask=None, training=None):
        """Encode ``inputs``.

        Args:
            inputs: Tensor of shape ``(batch_size, seq_length, embed_dim)``.
            mask: Accepted for interface compatibility; not used (the
                self-attention runs unmasked).
            training: Accepted so callers can pass ``training=...`` the same
                way they do for the decoder; not used, since this block has
                no train/eval dependent layers.

        Returns:
            Tensor with the same shape as ``inputs``.
        """
        projected = self.dense_proj(inputs)
        attention_output, _ = self.attention(
            projected, projected, projected, attn_mask=None
        )
        # Residual connection around the attention, then normalisation.
        return self.layernorm_1(projected + attention_output)

# Smoke test: push one random batch through an encoder block and print the
# input and output shapes. The tensor layout is (batch_size, seq_length, embed_dim).
block = TransformerEncoderBlock(
    embed_dim=EMBED_DIM,
    dense_dim=FF_DIM,
    num_heads=NUM_HEADS,
)
x = torch.randn((BATCH_SIZE, SEQ_LENGTH, EMBED_DIM))
print(x.shape)
output = block(x)
print(output.shape)


torch.Size([64, 25, 128])
torch.Size([64, 25, 128])


In [3]:
import torch
import torch.nn as nn

class PositionalEmbedding(nn.Module):
    """Adds a learned position embedding to a learned token embedding.

    Args:
        sequence_length: Number of distinct positions the position table holds.
        vocab_size: Number of rows in the token table.
        embed_dim: Width of both embedding tables.
    """

    def __init__(self, sequence_length, vocab_size, embed_dim):
        super().__init__()
        # Token table first, position table second (kept in this order so
        # parameter registration and initialisation order are unchanged).
        self.token_embeddings = nn.Embedding(vocab_size, embed_dim)
        self.position_embeddings = nn.Embedding(sequence_length, embed_dim)
        self.sequence_length = sequence_length
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim

    def forward(self, inputs):
        """Embed a 2-D tensor of token ids.

        Args:
            inputs: Integer tensor of shape ``(batch_size, seq_length)``.

        Returns:
            Tensor of shape ``(batch_size, seq_length, embed_dim)``: the sum
            of the token embedding and the embedding of each token's index
            along dim 1.
        """
        n_rows, n_cols = inputs.shape
        index_row = torch.arange(n_cols, device=inputs.device)
        # Same index row repeated (as a view) for every sample of the batch.
        position_ids = index_row[None, :].expand(n_rows, n_cols)
        token_part = self.token_embeddings(inputs)
        position_part = self.position_embeddings(position_ids)
        return token_part + position_part

    def compute_mask(self, inputs, mask=None):
        """Return a mask that is true wherever ``inputs`` is non-zero.

        The ``mask`` argument is accepted but not used.
        """
        non_zero = inputs != 0
        return non_zero


In [30]:
# Smoke test for the PositionalEmbedding class.

# Hyper-parameters used for this check
sequence_length, vocab_size, embed_dim = 128, 1000, 128

# Layer under test
pos_embedding = PositionalEmbedding(
    sequence_length=sequence_length,
    vocab_size=vocab_size,
    embed_dim=embed_dim,
)

# Random token ids for a batch of 32 sequences
import torch
inputs = torch.randint(0, vocab_size, (32, sequence_length))

# Forward pass
embedded_output = pos_embedding(inputs)

# Report the shape of the result
print("Shape of embedded output:", embedded_output.shape)


Shape of embedded output: torch.Size([32, 128, 128])


In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class TransformerDecoderBlock(nn.Module):
    """Caption decoder block.

    Pipeline: token + position embedding -> causal self-attention ->
    cross-attention over the encoder outputs -> feed-forward network ->
    linear projection to vocabulary scores. Every sub-layer is followed by a
    residual connection and LayerNorm. All tensors are batch-first.

    Args:
        embed_dim: Width of the token embeddings and of both attention layers.
        ff_dim: Hidden width of the feed-forward network.
        num_heads: Number of attention heads; must divide ``embed_dim``.
        vocab_size: Number of tokens; size of the output projection.
        seq_length: Number of positions in the position-embedding table.
            Defaults to the module-level ``SEQ_LENGTH`` when omitted.
    """

    def __init__(self, embed_dim, ff_dim, num_heads, vocab_size, seq_length=None):
        super().__init__()
        self.embed_dim = embed_dim
        self.ff_dim = ff_dim
        self.num_heads = num_heads
        self.vocab_size = vocab_size

        # Multi-head attention layers. batch_first=True because every tensor
        # handled here is (batch, seq, embed); without it PyTorch would treat
        # dim 0 as the sequence axis and attend across batch samples.
        self.attention_1 = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.attention_2 = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

        # Feed-forward network
        self.dense_proj = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, embed_dim)
        )

        # Layer normalization
        self.layernorm_1 = nn.LayerNorm(embed_dim)
        self.layernorm_2 = nn.LayerNorm(embed_dim)
        self.layernorm_3 = nn.LayerNorm(embed_dim)

        # Positional embedding
        if seq_length is None:
            seq_length = SEQ_LENGTH
        self.embedding = PositionalEmbedding(seq_length, vocab_size, embed_dim)

        # Output layer
        self.out = nn.Linear(embed_dim, vocab_size)

        # Dropout layers (their probabilities are applied in forward()).
        self.dropout_1 = nn.Dropout(0.1)
        self.dropout_2 = nn.Dropout(0.5)

    def forward(self, inputs, encoder_outputs, training, mask=None):
        """Predict vocabulary scores for every position of ``inputs``.

        Args:
            inputs: Integer token ids of shape ``(batch_size, seq_length)``.
            encoder_outputs: Tensor of shape
                ``(batch_size, source_length, embed_dim)``.
            training: Truthy to apply dropout, falsy to disable it. ``None``
                falls back to the module's own train/eval mode.
            mask: Optional ``(batch_size, seq_length)`` tensor, truthy where
                the token is real and falsy where it is padding.

        Returns:
            Tensor of shape ``(batch_size, seq_length, vocab_size)`` holding
            unnormalised scores.
        """
        apply_dropout = self.training if training is None else bool(training)

        embedded = self.embedding(inputs)
        embedded = F.dropout(embedded, p=self.dropout_1.p, training=apply_dropout)

        # The causal mask is applied unconditionally: a position must never
        # see later tokens, whether or not a padding mask was supplied.
        self_attn_mask = self._build_self_attention_mask(
            embedded.size(1), mask, embedded.device
        )
        attention_output_1, _ = self.attention_1(
            embedded, embedded, embedded, attn_mask=self_attn_mask
        )
        out_1 = self.layernorm_1(embedded + attention_output_1)

        # Cross-attention runs unmasked: the padding mask describes the
        # target tokens, not the encoder positions that serve as keys here.
        attention_output_2, _ = self.attention_2(out_1, encoder_outputs, encoder_outputs)
        out_2 = self.layernorm_2(out_1 + attention_output_2)

        proj_output = self.dense_proj(out_2)
        proj_out = self.layernorm_3(out_2 + proj_output)
        proj_out = F.dropout(proj_out, p=self.dropout_2.p, training=apply_dropout)

        preds = self.out(proj_out)
        return preds

    def _build_self_attention_mask(self, seq_len, mask, device):
        """Build the boolean self-attention mask in PyTorch convention.

        ``True`` marks a (query, key) pair that must NOT attend. Returns a
        ``(seq_len, seq_len)`` mask when ``mask`` is None, otherwise a
        ``(batch_size * num_heads, seq_len, seq_len)`` mask that also blocks
        padded keys.
        """
        # Strict upper triangle = keys that lie in the query's future.
        blocked = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=device), diagonal=1
        )
        if mask is None:
            return blocked

        padded_keys = ~mask.to(device=device, dtype=torch.bool)   # (B, L)
        blocked = blocked.unsqueeze(0) | padded_keys.unsqueeze(1)  # (B, L, L)
        # Always let a position attend to itself: a row with every key
        # blocked would make the softmax return NaN.
        self_pairs = torch.eye(seq_len, dtype=torch.bool, device=device)
        blocked = blocked & ~self_pairs
        # nn.MultiheadAttention wants one mask per (sample, head), laid out
        # with the heads of a sample adjacent to each other.
        return blocked.repeat_interleave(self.num_heads, dim=0)

    def get_causal_attention_mask(self, inputs):
        """Return a lower-triangular mask of ones, one copy per sample.

        The result has shape ``(batch_size, seq_length, seq_length)`` and
        uses 1 for "may attend" and 0 for "blocked". That is the opposite of
        PyTorch's boolean ``attn_mask`` convention, so ``forward`` does not
        pass this tensor to the attention layers; the method is kept for
        callers that want the mask in this form.
        """
        batch_size, sequence_length, _ = inputs.size()
        mask = torch.tril(torch.ones(sequence_length, sequence_length, device=inputs.device))
        mask = mask.unsqueeze(0).expand(batch_size, -1, -1)  # [batch_size, seq_length, seq_length]
        return mask


In [6]:
# Smoke test for the TransformerDecoderBlock class.

# Small hyper-parameters for this check
embed_dim, ff_dim, num_heads, vocab_size = 6, 6, 3, 23

# Decoder under test
decoder_block = TransformerDecoderBlock(
    embed_dim=embed_dim,
    ff_dim=ff_dim,
    num_heads=num_heads,
    vocab_size=vocab_size,
)

# Sample data: a batch of 2 sequences of length 25
import torch
inputs = torch.randint(0, vocab_size, (2, 25))   # token ids, (batch_size, sequence_length)
encoder_outputs = torch.randn(2, 25, embed_dim)  # (batch_size, sequence_length, embed_dim)

# Token ids are expected as (batch_size, sequence_length)
print("Shape of inputs tensor:", inputs.shape)

# Encoder outputs are expected as (batch_size, sequence_length, embed_dim)
print("Shape of encoder outputs tensor:", encoder_outputs.shape)

# Forward pass
predictions = decoder_block(inputs, encoder_outputs, training=True)

# Report the shape of the result
print("Shape of predictions:", predictions.shape)


Shape of inputs tensor: torch.Size([2, 25])
Shape of encoder outputs tensor: torch.Size([2, 25, 6])
Shape of predictions: torch.Size([2, 25, 23])


In [7]:
# Print the full `predictions` tensor returned by the decoder smoke test.
print(predictions)

tensor([[[-2.0606e-01, -1.7114e-01, -9.4247e-01,  ..., -6.6829e-01,
           5.7754e-01,  4.7863e-04],
         [-1.3249e+00, -6.6958e-01, -8.1786e-01,  ..., -7.1838e-01,
           8.7098e-01,  1.6100e+00],
         [ 3.4973e-01,  3.9239e-02, -2.0889e-01,  ..., -3.5631e-01,
           1.1971e-01, -4.3944e-01],
         ...,
         [-5.7583e-01, -1.1395e+00, -1.6274e+00,  ..., -1.2789e+00,
           1.0947e+00,  1.1039e+00],
         [-3.0398e-01,  3.2044e-01, -1.7796e+00,  ..., -1.2521e+00,
          -5.1979e-01, -7.2534e-01],
         [ 1.5497e+00, -5.7710e-01, -3.9896e-01,  ..., -1.5744e-01,
           1.2541e+00, -1.0405e+00]],

        [[-6.6756e-01,  4.2434e-01, -8.1626e-01,  ..., -2.3648e+00,
           8.8223e-02,  8.7488e-01],
         [-9.6958e-01, -9.0369e-01, -1.5932e+00,  ..., -1.0449e+00,
           8.2456e-01,  1.1915e+00],
         [ 5.8468e-01, -3.5219e-01, -8.7830e-01,  ..., -4.6111e-01,
           7.3699e-01, -5.6171e-01],
         ...,
         [-2.3316e-01, -6

In [30]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from swin_transformer import *

# Location of the pretrained Swin classifier weights.
# NOTE(review): absolute, machine-specific path; adjust it for your environment.
SWIN_WEIGHTS_PATH = "/home/ayushh/MedMNIST_SWIN_Transformer/swin_classification_pytorch_model_weights.pth"

# Use the GPU when one is present, otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# NOTE(review): `embed_dim` and `num_heads` are also reassigned by the example
# cells earlier in this notebook. Unless the star import above rebinds them,
# the values used here depend on which cells ran last; confirm they are the
# ones the checkpoint was trained with.
model_swin = SwinModel(embed_dim, num_patch_x, num_patch_y, num_heads, num_mlp, window_size, shift_size, qkv_bias, num_classes)

# map_location lets a checkpoint that was saved from a GPU load on any device.
model_swin.load_state_dict(torch.load(SWIN_WEIGHTS_PATH, map_location=device))
model_swin.to(device)


class ImageCaptioningModel(nn.Module):
    """Image captioning model: image backbone -> encoder -> caption decoder.

    Args:
        model_swin: Module that turns a batch of images into image embeddings.
        encoder: Module called as ``encoder(img_embed, training=...)``.
        decoder: Module called as
            ``decoder(token_ids, encoder_out, training=..., mask=...)`` and
            returning scores of shape ``(batch, seq, vocab)``.
        num_captions_per_image: How many captions each image has along dim 1
            of the caption tensor.
    """

    def __init__(self, model_swin, encoder, decoder, num_captions_per_image=5):
        super().__init__()
        self.swin_model = model_swin
        self.encoder = encoder
        self.decoder = decoder
        self.num_captions_per_image = num_captions_per_image

    def forward(self, inputs):
        """Run the full pipeline.

        ``inputs`` is indexable: ``inputs[0]`` are the images, ``inputs[1]``
        is forwarded to the decoder as its ``training`` flag and
        ``inputs[2]`` are the decoder token ids.
        """
        img_embed = self.swin_model(inputs[0])
        encoder_out = self.encoder(img_embed, False)
        decoder_out = self.decoder(inputs[2], encoder_out, training=inputs[1], mask=None)
        return decoder_out

    def calculate_loss(self, y_true, y_pred, mask):
        """Mean cross-entropy over the positions selected by ``mask``.

        Args:
            y_true: Target token ids, shape ``(batch, seq)``.
            y_pred: Scores, shape ``(batch, seq, vocab)``.
            mask: Tensor of shape ``(batch, seq)``, true where the position
                counts towards the loss.

        Returns:
            Scalar tensor; ``0`` when ``mask`` selects no position at all.
        """
        # cross_entropy wants the class axis on dim 1, hence the transpose.
        per_token = F.cross_entropy(
            y_pred.transpose(1, 2), y_true, ignore_index=0, reduction='none'
        )
        per_token = per_token * mask
        # clamp keeps an all-padding batch from producing 0 / 0 = NaN.
        return per_token.sum() / mask.sum().clamp(min=1)

    def calculate_accuracy(self, y_true, y_pred, mask):
        """Fraction of masked positions whose arg-max prediction is correct.

        Takes the same arguments as ``calculate_loss``; ``mask`` must be a
        boolean tensor. Returns ``0`` when ``mask`` selects no position.
        """
        pred_ids = torch.argmax(y_pred, dim=-1)
        correct = (pred_ids == y_true) & mask
        return correct.sum().float() / mask.sum().clamp(min=1)

    def _run_step(self, batch_data, training):
        """Shared body of ``train_step`` and ``test_step``.

        Computes, for every caption of every image, the next-token loss and
        accuracy; returns the summed loss and the mean accuracy.
        """
        batch_img, batch_seq = batch_data
        batch_loss = 0
        batch_acc = 0

        img_embed = self.swin_model(batch_img)

        for i in range(self.num_captions_per_image):
            encoder_out = self.encoder(img_embed, training=training)

            # Teacher forcing: feed tokens [0, n-1), predict tokens [1, n).
            batch_seq_inp = batch_seq[:, i, :-1]
            batch_seq_true = batch_seq[:, i, 1:]

            # Token id 0 is padding and is excluded from loss and accuracy.
            mask = batch_seq_true != 0

            batch_seq_pred = self.decoder(
                batch_seq_inp, encoder_out, training=training, mask=mask
            )

            batch_loss += self.calculate_loss(batch_seq_true, batch_seq_pred, mask)
            batch_acc += self.calculate_accuracy(batch_seq_true, batch_seq_pred, mask)

        return {
            "loss": batch_loss,
            "acc": batch_acc / float(self.num_captions_per_image),
        }

    def train_step(self, batch_data):
        """Compute loss and accuracy for one batch with ``training=True``.

        ``batch_data`` is ``(images, captions)`` where ``captions`` has shape
        ``(batch, num_captions_per_image, seq)``. Returns a dict with the
        loss summed over captions (``"loss"``) and the accuracy averaged over
        captions (``"acc"``). No backward pass or optimiser step is run here.
        """
        return self._run_step(batch_data, training=True)

    def test_step(self, batch_data):
        """Same as ``train_step`` but calls the sub-modules with ``training=False``."""
        return self._run_step(batch_data, training=False)

    @property
    def metrics(self):
        return []  # Returning an empty list as PyTorch handles metrics differently
