In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torchtext.datasets import multi30k, Multi30k
from torch.utils.data import Dataset, DataLoader,random_split
from typing import Iterable, List
import torchtext
from sklearn.model_selection import train_test_split
import math
from tqdm import tqdm
# from google.colab import files

# import torch_xla
# import torch_xla.core.xla_model as xm
# Place-holders
token_transform = {}
vocab_transform = {}
# # Installing dependencies
# !pip install -U torchdata
# !pip install -U spacy
# !pip install 'portalocker>=2.0.0'
!python -m spacy download en_core_web_sm
!python -m spacy download fr_core_news_sm

Collecting en-core-web-sm==3.7.1
  Downloading https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.1/en_core_web_sm-3.7.1-py3-none-any.whl (12.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m12.8/12.8 MB[0m [31m74.9 MB/s[0m eta [36m0:00:00[0m00:01[0m0:01[0m
[38;5;2m✔ Download and installation successful[0m
You can now load the package via spacy.load('en_core_web_sm')
Collecting fr-core-news-sm==3.7.0
  Downloading https://github.com/explosion/spacy-models/releases/download/fr_core_news_sm-3.7.0/fr_core_news_sm-3.7.0-py3-none-any.whl (16.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m16.3/16.3 MB[0m [31m68.5 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[38;5;2m✔ Download and installation successful[0m
You can now load the package via spacy.load('fr_core_news_sm')


In [2]:
batch_size = 64
block_size = 25
learning_rate = 3e-4
epochs = 5
eval_interval = 500
eval_iters = 200
device = 'cuda' if torch.cuda.is_available() else 'cpu'
n_embd = 384
dropout = 0.2
no_of_heads = 6
n_layer = 6
device
SRC_LANGUAGE = 'Fr'
TGT_LANGUAGE = 'En'
# device = xm.xla_device()
device

'cuda'

In [3]:
data = pd.read_csv("/kaggle/input/englishtofrench/eng-fra.txt", sep="\t", header=None)
data = data.set_axis(['En','Fr'], axis = 1) # Rename indices
data.head()

Unnamed: 0,En,Fr
0,Go.,Va !
1,Run!,Cours !
2,Run!,Courez !
3,Wow!,Ça alors !
4,Fire!,Au feu !


In [4]:
token_transform[SRC_LANGUAGE] = get_tokenizer('spacy', language='fr_core_news_sm')
token_transform[TGT_LANGUAGE] = get_tokenizer('spacy', language='en_core_web_sm')

# helper function to yield list of tokens
def yield_tokens(data_iter: Iterable, language: str) -> List[str]:
    language_index = {SRC_LANGUAGE: 0, TGT_LANGUAGE: 1}

    for data_sample in data_iter[language]:
        yield token_transform[language](data_sample)

# Define special symbols and indices
UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = 0, 1, 2, 3
# Make sure the tokens are in order of their indices to properly insert them in vocab
special_symbols = ['<unk>', '<pad>', '<bos>', '<eos>']

for ln in [SRC_LANGUAGE, TGT_LANGUAGE]:
    print(ln)
    # Training data Iterator
    # Create torchtext's Vocab object
    vocab_transform[ln] = build_vocab_from_iterator(yield_tokens(data, ln),
                                                    min_freq=1,
                                                    specials=special_symbols,
                                                    special_first=True)

# Set ``UNK_IDX`` as the default index. This index is returned when the token is not found.
# If not set, it throws ``RuntimeError`` when the queried token is not found in the Vocabulary.
for ln in [SRC_LANGUAGE, TGT_LANGUAGE]:
    vocab_transform[ln].set_default_index(UNK_IDX)

Fr
En


In [5]:
print(vocab_transform["En"].lookup_token(200))
print(vocab_transform["En"].lookup_indices(["left"]))
vocab_size_src = len(vocab_transform[SRC_LANGUAGE])
vocab_size_tgt = len(vocab_transform[TGT_LANGUAGE])
print(f"Vocab size for {SRC_LANGUAGE} = {vocab_size_src}")
print(f"Vocab size for {TGT_LANGUAGE} = {vocab_size_tgt}")

left
[200]
Vocab size for Fr = 24554
Vocab size for En = 14875


In [6]:
# Step 1: Encode a sentence
def encode_sentence(sentence: str, language: str, tokenizer, vocab) -> List[int]:
    # Tokenize the sentence
    tokens = tokenizer(sentence)
    # Convert tokens to indices using vocabulary
    indices = vocab_transform[language].lookup_indices(tokens)
    return indices

# Step 2: Decode a sequence
def decode_sequence(indices: List[int], language: str, vocab) -> str:
    # Convert indices to tokens
    tokens = [vocab_transform[language].lookup_token(index) for index in indices]
    # Remove <bos> and <eos> tokens if present
    if tokens[0] == '<bos>':
        tokens = tokens[1:]
    if tokens[-1] == '<eos>':
        tokens = tokens[:-1]
    # Convert tokens to a sentence
    sentence = ""
    for token in tokens:
        # if token == '<bos>' or  token == '<eos>': or token == '<pad>':
        #     continue
        sentence = sentence + " " + token
    return sentence

# Example usage
sentence = "Je suis froid"
encoded = encode_sentence(sentence, SRC_LANGUAGE, token_transform[SRC_LANGUAGE], vocab_transform[SRC_LANGUAGE])
decoded = decode_sequence(encoded, SRC_LANGUAGE, vocab_transform[SRC_LANGUAGE])
print("Original sentence:", sentence)
print("Encoded sequence:", encoded)
print("Decoded sentence:", decoded)

Original sentence: Je suis froid
Encoded sequence: [6, 34, 448]
Decoded sentence:  Je suis froid


In [7]:
from torch.nn.utils.rnn import pad_sequence

# helper function to club together sequential operations
def sequential_transforms(*transforms):
    def func(txt_input):
        for transform in transforms:
            txt_input = transform(txt_input)
        return txt_input
    return func

# function to add BOS/EOS and create tensor for input sequence indices
def tensor_transform(token_ids: List[int]):
    return torch.cat((torch.tensor([BOS_IDX]),
                      torch.tensor(token_ids),
                      torch.tensor([EOS_IDX])))

# ``src`` and ``tgt`` language text transforms to convert raw strings into tensors indices
text_transform = {}
for ln in [SRC_LANGUAGE, TGT_LANGUAGE]:
    text_transform[ln] = sequential_transforms(token_transform[ln], #Tokenization
                                               vocab_transform[ln], #Numericalization
                                               tensor_transform) # Add BOS/EOS and create tensor


# function to collate data samples into batch tensors
def collate_fn(batch):
    src_batch, tgt_batch = [], []
    for src_sample, tgt_sample in batch:
        src_sample = text_transform[SRC_LANGUAGE](src_sample.rstrip("\n"))
        src_sample = src_sample[:block_size]
        tgt_sample = text_transform[TGT_LANGUAGE](tgt_sample.rstrip("\n"))
        tgt_sample = tgt_sample[:block_size]
        src_batch.append(src_sample)
        tgt_batch.append(tgt_sample)

        # src_batch.append(text_transform[SRC_LANGUAGE](src_sample.rstrip("\n")))
        # tgt_batch.append(text_transform[TGT_LANGUAGE](tgt_sample.rstrip("\n")))

    src_batch = pad_sequence(src_batch, padding_value=PAD_IDX)
    tgt_batch = pad_sequence(tgt_batch, padding_value=PAD_IDX)

    return src_batch.T, tgt_batch.T

In [8]:
class CustomDataset(Dataset):
    def __init__(self, inputText, outputText):
        self.inputText = inputText
        self.outputText = outputText

    def __len__(self):
        return len(self.inputText)

    def __getitem__(self, idx):
        x = self.inputText[idx]
        outputText = self.outputText[idx]
        return x, outputText

dataset = CustomDataset(data["Fr"], data["En"])


In [9]:
for x, y in dataset:
    print("x = ", x)
    print("y = ", y)
    break

x =  Va !
y =  Go.


In [10]:
# Train, validation split
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

# Create a DataLoader to iterate over batches of data and performing preprocessing - By default produces batch first
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, collate_fn = collate_fn, shuffle=False)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, collate_fn = collate_fn)

In [11]:
i = 0
for x, y in train_dataloader:
    print(x.shape)
    for src in x:
        print("X input", decode_sequence(src, SRC_LANGUAGE, vocab_transform[SRC_LANGUAGE]))
        # print("Y input", decode_sequence(tgt, TGT_LANGUAGE, vocab_transform[TGT_LANGUAGE]))
    print("y shape = ", y.shape)
    for item in y:
        print("Y labels", decode_sequence(item, TGT_LANGUAGE, vocab_transform[TGT_LANGUAGE]))
    if i == 1:
        break
    i += 1


torch.Size([64, 19])
X input  Nous avons attendu des heures durant . <eos> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
X input  Il était résolu à y aller seul . <eos> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
X input  Je suis venue vous donner ceci . <eos> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
X input  J’ aime les chameaux . <eos> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
X input  Voulez -vous que je vous véhicule ? <eos> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
X input  Ces derniers temps , je ne l' ai pas souvent vu . <eos> <pad> <pad> <pad> <pad> <pad>
X input  Elles ont ri . <eos> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
X input  Vous avez l' air très fatigué . <eos> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
X input  Pourquoi ne pas passer la nuit ? <eos> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
X input  Je n' a

In [None]:
sample = x[0]
print(sample.shape)
mask = (sample == PAD_IDX)
print(mask.shape)
print(decode_sequence(sample, SRC_LANGUAGE, vocab_transform[SRC_LANGUAGE]))

# h = Head(64, mask = False)
key = nn.Linear(1, 64, bias = False)
query = nn.Linear(1, 64, bias = False)
value = nn.Linear(1, 64, bias = False)
       
sample = sample.to(dtype=torch.float)
sample = sample.unsqueeze(-1)
k = key(sample)   # (B, T, head_size)
q = query(sample) # (B, T, head_size)
v = value(sample) # (B, T, head_size)

mask = mask.unsqueeze(1)
w = q @ k.transpose(-2, -1) # (B,T,H) @ (B,H,T)
print(w.shape)
w = w.masked_fill(mask, float('-inf')) # (B, T, T)
w = F.softmax(w, dim = -1)

print(w)

In [12]:
# One head of self attention
class Head(nn.Module):
    def __init__(self, head_size, mask = True):
        super().__init__()
        # Query, key, and value are all linear layers.
        self.key = nn.Linear(n_embd, head_size, bias = False)
        self.query = nn.Linear(n_embd, head_size, bias = False)
        self.value = nn.Linear(n_embd, head_size, bias = False)
        # create a tril matrix of ones
        # PyTorch naming convention because the tril is not a parameter
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(dropout)
        self.mask = mask
    def forward(self, x, query = None, key = None, value = None, x_padding_mask = None, context_padding_mask = None):
        B,T,C = x.shape

        if key is None:
            key = x
        if query is None:
            query = x
        if value is None:
            value = x

        k = self.key(key)   # (B, T, head_size)
        q = self.query(query) # (B, T, head_size)
        v = self.value(value) # (B, T, head_size)

        # print("Head: x shape",x.shape) # (64, 16, 384)
        # print("Head: query shape", q.shape) # (64, 16, 64)
        # print("Head key shape", k.shape) # (64, 19, 64)
        # print("Head value shape", v.shape) # (64, 19, 64)

        # Dot product the key and the query to get the weights
        w = q @ k.transpose(-2, -1) # (B,T,H) @ (B,H,T)

        # Dividing by sqrt(head_size) for stability and making sure the variance stays close to zero
        w = w * (C ** -0.5)

        if self.mask: # This is the decoder mask (Masks future tokens)
            w = w.masked_fill((self.tril[:T, :T] == 0), float('-inf'))

        # w = (B, target_T, C)
        if x_padding_mask is not None: # This is the padding mask which masks the padding tokens
            x_padding_mask = x_padding_mask.unsqueeze(1) # (B, 1, T)
            # print("W shape", w.shape) # (B, T, T)
            w = w.masked_fill(x_padding_mask, float('-inf')) # (B, T, T)

        w = F.softmax(w, dim = -1)
        w = self.dropout(w)

        out = w @ v # (B, T, T) @ (B, T, C) = (B, T, C) cuz B stays the same so essentially its a (T, T) @ (T, C)
        return out

class MultiHeadAttention(nn.Module):
    def __init__(self, no_of_heads, head_size, mask = True):
        super().__init__()
        self.mask = mask
        self.heads = nn.ModuleList([Head(head_size, self.mask) for _ in range(no_of_heads)])
        self.proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, query = None, key = None, value = None, x_padding_mask = None):
        if key is None:
            key = x
        if query is None:
            query = x
        if value is None:
            value = x
        out = torch.cat([head(x, query, key, value, x_padding_mask = x_padding_mask) for head in self.heads], dim = -1)
        out = self.proj(out)
        out = self.dropout(out)
        return out

class FeedForward(nn.Module):
    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout)
        )
        self.ln = nn.LayerNorm(n_embd)
    def forward(self, x):
            x = x + self.net(x)
            x = self.ln(x)
            return x

class GlobalSelfAttention(nn.Module):
    def __init__(self, n_embd, no_of_heads, mask):
        super().__init__()
        head_size = n_embd // no_of_heads
        self.mask = mask
        self.mha = MultiHeadAttention(no_of_heads, head_size, self.mask)
        self.ffwd = FeedForward(n_embd)
        self.ln = nn.LayerNorm(n_embd)

    def forward(self, x, padding_mask):
        x = x + self.mha(x, x_padding_mask = padding_mask)
        x = self.ln(x)
        return x

class CrossAttention(nn.Module):
    def __init__(self, n_embd, no_of_heads):
        super().__init__()
        head_size = n_embd // no_of_heads
        self.ca = MultiHeadAttention(no_of_heads, head_size, mask = False)
        self.ln = nn.LayerNorm(n_embd)

    def forward(self, x, context, x_padding_mask = None):
        # print("cross attn")
        x = x + self.ca(x=x, query=x, key=context, value=context, x_padding_mask = x_padding_mask)
        x = self.ln(x)
        return x

class EncoderLayer(nn.Module):
    def __init__(self, n_embd, no_of_heads):
        super().__init__()
        self.sa = GlobalSelfAttention(n_embd, no_of_heads, mask = False)
        self.ffn = FeedForward(n_embd)

    def forward(self, x, padding_mask = None):
        x = self.sa(x, padding_mask)
        x = self.ffn(x)
        return x
# Need to change!!!!
# class PositionalEncoding(nn.Module):
#     def __init__(self, emb_size: int, dropout: float, maxlen: int = 5000):
#         super(PositionalEncoding, self).__init__()
#         den = torch.exp(- torch.arange(0, emb_size, 2)* math.log(10000) / emb_size)
#         pos = torch.arange(0, maxlen).reshape(1, maxlen)
#         pos_embedding = torch.zeros((maxlen, emb_size))
#         pos_embedding[:, 0::2] = torch.sin(pos * den)
#         pos_embedding[:, 1::2] = torch.cos(pos * den)
#         pos_embedding = pos_embedding.unsqueeze(0)

#         self.dropout = nn.Dropout(dropout)
#         self.register_buffer('pos_embedding', pos_embedding)

#     def forward(self, token_embedding):
#         return self.dropout(token_embedding + self.pos_embedding[:, :token_embedding.size(0)])

class Encoder(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[EncoderLayer(n_embd=n_embd, no_of_heads=no_of_heads) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, idx, padding_mask = None):
        B, T = idx.shape
        tok_emb = self.token_embedding_table(idx)
        pos_emb = self.position_embedding_table(torch.arange(T, device = device))
        x = tok_emb + pos_emb
        for i in range(n_layer):
            x  = self.blocks[i](x, padding_mask = padding_mask)
        x = self.ln_f(x)
        return x

class DecoderLayer(nn.Module):
    def __init__(self, n_embd, no_of_heads):
        super().__init__()
        self.masked_attn = GlobalSelfAttention(n_embd, no_of_heads, mask = True)
        self.crs_attn = CrossAttention(n_embd = n_embd, no_of_heads = no_of_heads)
        self.ffn = FeedForward(n_embd)

    def forward(self, x, context, x_padding_mask = None, context_padding_mask = None):
        x = self.masked_attn(x, x_padding_mask)
        x = self.crs_attn(x = x, context = context, x_padding_mask = context_padding_mask)
        x = self.ffn(x)
        return x


class Decoder(nn.Module):
    def __init__(self, vocab_size):
        super(Decoder, self).__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.dropout = nn.Dropout(dropout)
        self.blocks = nn.Sequential(*[DecoderLayer(n_embd=n_embd, no_of_heads=no_of_heads) for _ in range(n_layer)])

    def forward(self, x, context, x_padding_mask = None, context_padding_mask = None):
    # `x` is token-IDs shape (batch, target_seq_len)
        B, T = x.shape
        tok_emb = self.token_embedding_table(x) # (B, T, C)
        pos_emb = self.position_embedding_table(torch.arange(T, device = device)) # (T, C)
        x = tok_emb + pos_emb # (B, T, C)
        x = self.dropout(x)
        for i in range(n_layer):
            x  = self.blocks[i](x, context, x_padding_mask = x_padding_mask, context_padding_mask = context_padding_mask)
        # The shape of x is (batch_size, target_seq_len, d_model).
        return x

class Transformer(nn.Module):
    def __init__(self, vocab_size_src, vocab_size_tgt, n_embd):
        super().__init__()
        self.encoder = Encoder(vocab_size_src)
        self.decoder = Decoder(vocab_size_tgt)
        self.final_layer = nn.Linear(n_embd, vocab_size_tgt)

    def forward(self, inputs):
        context, outputs = inputs
        context = context.to(device)
        outputs = outputs.to(device)
        y_input = outputs[:, :-1]
        y_labels = outputs[:, 1:]

        context_padding_mask = (context == PAD_IDX)
        y_input_padding_mask = (y_input == PAD_IDX)

        context = self.encoder(context, padding_mask = context_padding_mask)
        x = self.decoder(x = y_input, context = context, x_padding_mask = y_input_padding_mask, context_padding_mask = context_padding_mask)  # (batch_size, target_len, d_model)
        
        logits = self.final_layer(x) # B, T, vocab_size_tgt

        B,T,C = logits.shape

        logits = logits.view(B * T, C)
        # print("logits shape = ", logits.shape)
        y_labels = y_labels.reshape(B * T)
        # print("y labels shape =", y_labels.shape)
        loss = F.cross_entropy(logits, y_labels, ignore_index=PAD_IDX)
        return logits, loss

    def generate(self, src_sentence, max_new_tokens):
        idx = torch.tensor([2], device = device).view(1, -1) # token 2 is <bos> token
        encoded_sentence = encode_sentence(src_sentence, SRC_LANGUAGE, token_transform[SRC_LANGUAGE], vocab_transform[SRC_LANGUAGE])
        encoded_sentence = torch.tensor(encoded_sentence, device = device).view(1, -1)
        # print(encoded_sentence.shape)
        context = self.encoder(encoded_sentence)
        #idx is (B,T)
        for _ in range(max_new_tokens):
            # Cropping the idx to the last block_size tokens
            idx_cond = idx[:, :block_size]
            x = self.decoder(x = idx_cond, context = context)
            logits = self.final_layer(x) # B, T, vocab_size_tgt
            logits = logits[:, -1, :] # Becomes (B, C)
            probs = F.softmax(logits, dim = -1)

            # Sampling from distribution
            idx_next = torch.multinomial(probs, num_samples = 1)
            if idx_next == 3: # Token 3 is <EOS>
                break
            idx = torch.cat((idx, idx_next), dim = 1)
        idx = idx.tolist()
        return decode_sequence(idx[0], TGT_LANGUAGE, vocab_transform[TGT_LANGUAGE])

In [15]:
# Unit test for Transformer
# transformer = Transformer(vocab_size_src = vocab_size_src, vocab_size_tgt = vocab_size_tgt, n_embd = n_embd)
transformer = transformer.to(device)
logits, loss = transformer((x, y))

print("Hi")
print(x.shape)
print(y.shape)
print(logits.shape)
print(loss)

Hi
torch.Size([64, 25])
torch.Size([64, 23])
torch.Size([1408, 14875])
tensor(0.1088, device='cuda:0', grad_fn=<NllLossBackward0>)


In [20]:
# Sample generation
# idx = torch.zeros((1,1), dtype = torch.long, device = device) # stands for the new line token \n
sentence = transformer.generate(src_sentence = "Je pense que tu as excellé.", max_new_tokens = 200)
print(sentence)

 You 're on your way .


In [None]:
@torch.no_grad()
def estimate_loss():
    transformer.eval()
    total_samples = len(val_dataloader)
    losses = torch.zeros(total_samples)
    for i, data in enumerate(val_dataloader):
        logits, loss = transformer(data)
        # print(loss)
        losses[i] = loss.item()
        # print(i)
    out = losses.mean()
    transformer.train()
    return out

In [None]:
optimizer = torch.optim.AdamW(transformer.parameters(), lr = 3e-4)

In [None]:
train_loss_graph = []
val_loss_graph = []

In [None]:
train_losses = torch.zeros(len(train_dataloader), requires_grad=False)
for iter in range(2):
    # Iterate over the batches in the train_dataloader
    for i, (X, Y) in tqdm(enumerate(train_dataloader)):
        X = X.to(device)
        Y = Y.to(device)
        optimizer.zero_grad(set_to_none = True)
        logits, loss = transformer((X, Y))
        loss.backward()
        optimizer.step()
        # xm.optimizer_step(optimizer)
        # xm.mark_step()
        train_losses[i] = loss
    losses = estimate_loss()
    print(f"Step {iter}: Train loss = {train_losses.mean()} Val loss = {losses}")
    train_loss_graph.append(train_losses.mean())
    val_loss_graph.append(losses)
    train_losses = torch.zeros(len(train_dataloader), requires_grad=False)
    torch.save(transformer.state_dict(), f'New{iter+4}.pth')

In [None]:
# download checkpoint file
torch.save(transformer.state_dict(), 'NewTransformer3.pth')

In [14]:
transformer = Transformer(vocab_size_src = vocab_size_src, vocab_size_tgt = vocab_size_tgt, n_embd = n_embd)
state_dict = torch.load('/kaggle/working/NewTransformer3.pth')
transformer.load_state_dict(state_dict)
# print(state_dict.keys())

<All keys matched successfully>

In [None]:
import gc
del t
gc.collect()