In [1]:
from pathlib import Path
import random
import torch
import torch.nn as nn
import math

Path → safe file handling   
random → data inspection

In [2]:
DATA_RAW = Path("data/raw")     #original untouched data(read-only)
DATA_PROCESSED = Path("data/processed")     #cleaned + merged data will go(model reads from here)
TOKENIZER_DIR = Path("tokenizer")       #SentencePiece models will be saved here

DATA_PROCESSED.mkdir(parents=True, exist_ok=True)       #Creates folders only if they don’t exist
TOKENIZER_DIR.mkdir(parents=True, exist_ok=True)


In [3]:
# BPCC
bpcc_en = (DATA_RAW / "bpcc/train.en").read_text(encoding="utf-8").splitlines()
bpcc_sa = (DATA_RAW / "bpcc/train.sa").read_text(encoding="utf-8").splitlines()

# Samayik
sam_train_en = (DATA_RAW / "samayik/train.en").read_text(encoding="utf-8").splitlines()
sam_train_sa = (DATA_RAW / "samayik/train.sa").read_text(encoding="utf-8").splitlines()

sam_dev_en = (DATA_RAW / "samayik/dev.en").read_text(encoding="utf-8").splitlines()
sam_dev_sa = (DATA_RAW / "samayik/dev.sa").read_text(encoding="utf-8").splitlines()

sam_test_en = (DATA_RAW / "samayik/test.en").read_text(encoding="utf-8").splitlines()
sam_test_sa = (DATA_RAW / "samayik/test.sa").read_text(encoding="utf-8").splitlines()

# Load Itihasa (parallel, aligned)
iti_train_en = (DATA_RAW / "itihasa/train.en").read_text(encoding="utf-8").splitlines()
iti_train_sa = (DATA_RAW / "itihasa/train.sn").read_text(encoding="utf-8").splitlines()

iti_dev_en = (DATA_RAW / "itihasa/dev.en").read_text(encoding="utf-8").splitlines()
iti_dev_sa = (DATA_RAW / "itihasa/dev.sn").read_text(encoding="utf-8").splitlines()

iti_test_en = (DATA_RAW / "itihasa/test.en").read_text(encoding="utf-8").splitlines()
iti_test_sa = (DATA_RAW / "itihasa/test.sn").read_text(encoding="utf-8").splitlines()



Each file is read as UTF-8 (mandatory for Devanagari), split by newline, becomes List[str]

In [4]:
print("BPCC:", len(bpcc_en), len(bpcc_sa))

print("Samayik train:", len(sam_train_en), len(sam_train_sa))
print("Samayik dev:", len(sam_dev_en), len(sam_dev_sa))
print("Samayik test:", len(sam_test_en), len(sam_test_sa))

print("Itihasa train:", len(iti_train_en), len(iti_train_sa))
print("Itihasa dev:", len(iti_dev_en), len(iti_dev_sa))
print("Itihasa test:", len(iti_test_en), len(iti_test_sa))



BPCC: 98788 99424
Samayik train: 43493 43493
Samayik dev: 2416 2416
Samayik test: 2417 2417
Itihasa train: 75161 75161
Itihasa dev: 6148 6148
Itihasa test: 11721 11721


In [5]:
print(bpcc_en[0])
print(bpcc_sa[0]) 

There was no Mughal tradition of primogeniture, the systematic passing of rule, upon an emperor's death, to his eldest son.
चक्रवर्तिनः मृत्योः अनन्तरं तस्य शासनस्य व्यवस्थितरूपेण सङ्क्रमणस्य, मुघलपरम्परायाः ज्येष्ठपुत्राधिकारपद्धतिः नासीत्।


BPCC data is mismatched and not aligned, we'll use it for tokenizer training, not for model training.

In [6]:
i = random.randint(0, len(iti_train_en) - 1)
print("EN:", iti_train_en[i])
print("SA:", iti_train_sa[i])


EN: A king should also consult such minister as are free from the five kinds of deceit.
SA: संविनीय मदक्रोधौ मानमर्त्यां च निवृताः। नित्यं पञ्चोपधातीतैर्मन्त्रयेत् सह मन्त्रिभिः॥


In [7]:
i = random.randint(0, len(sam_train_en) - 1)
print("EN:", sam_train_en[i])
print("SA:", sam_train_sa[i])

EN: That is because we were in level 3 when we closed Ktouch.
SA: यतो हि यदा वयं के-टच पिहितवन्तः तदा तृतीयस्तरे एव आस्म।


In [8]:
# Translation training corpus
train_en = sam_train_en + iti_train_en
train_sa = sam_train_sa + iti_train_sa

print("Final TRAIN size:", len(train_en))


Final TRAIN size: 118654


In [9]:
# Tokenizer corpus 
tok_en = bpcc_en + train_en
tok_sa = bpcc_sa + train_sa

print("Tokenizer corpus:")
print("EN sentences:", len(tok_en))
print("SA sentences:", len(tok_sa))


Tokenizer corpus:
EN sentences: 217442
SA sentences: 218078


In [10]:
en_lengths = [len(s.split()) for s in train_en if s.strip()]

print("EN sentences:", len(en_lengths))
print("EN min:", min(en_lengths))
print("EN max:", max(en_lengths))
print("EN avg:", sum(en_lengths) / len(en_lengths))


EN sentences: 118654
EN min: 1
EN max: 1306
EN avg: 24.342188211101185


In [11]:
sa_lengths = [len(s) for s in train_sa if s.strip()]

print("SA sentences:", len(sa_lengths))
print("SA min:", min(sa_lengths))
print("SA max:", max(sa_lengths))
print("SA avg:", sum(sa_lengths) / len(sa_lengths))


SA sentences: 118654
SA min: 1
SA max: 2688
SA avg: 87.45049471572808


In [12]:
(DATA_PROCESSED / "tok_en.txt").write_text(
    "\n".join(tok_en),
    encoding="utf-8"
)


25313780

In [13]:
(DATA_PROCESSED / "tok_sa.txt").write_text(
    "\n".join(tok_sa),
    encoding="utf-8"
)


19646646

In [14]:
# Train English tokenizer ONLY if it doesn't exist
if not (TOKENIZER_DIR / "spm_en.model").exists():
    import sentencepiece as spm
    spm.SentencePieceTrainer.train(
        input=str(DATA_PROCESSED / "tok_en.txt"),
        model_prefix=str(TOKENIZER_DIR / "spm_en"),
        vocab_size=16000,
        model_type="unigram",
        character_coverage=1.0,
        pad_id=0,
        bos_id=1,
        eos_id=2,
        unk_id=3
    )
else:
    print("English tokenizer already exists. Skipping training.")


English tokenizer already exists. Skipping training.


In [15]:
# Train Sanskrit tokenizer ONLY if it doesn't exist
if not (TOKENIZER_DIR / "spm_sa.model").exists():
    import sentencepiece as spm
    spm.SentencePieceTrainer.train(
        input=str(DATA_PROCESSED / "tok_sa.txt"),
        model_prefix=str(TOKENIZER_DIR / "spm_sa"),
        vocab_size=32000,
        model_type="unigram",
        character_coverage=1.0,
        pad_id=0,
        bos_id=1,
        eos_id=2,
        unk_id=3
    )
else:
    print("Sanskrit tokenizer already exists. Skipping training.")


Sanskrit tokenizer already exists. Skipping training.


TOKENIZER (RUN ONCE ONLY)
DO NOT RETRAIN UNLESS DATA CHANGES

In [16]:
import torch

print("CUDA available:", torch.cuda.is_available())
print("GPU name:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "None")


CUDA available: True
GPU name: NVIDIA GeForce RTX 2050


In [17]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [18]:
import sentencepiece as spm

sp_en = spm.SentencePieceProcessor()
sp_en.load("tokenizer/spm_en.model")

sp_sa = spm.SentencePieceProcessor()
sp_sa.load("tokenizer/spm_sa.model")


True

In [19]:
MAX_SRC_LEN = 128
MAX_TGT_LEN = 128
BATCH_SIZE = 8


train_en_ids = [
    sp_en.encode(s, out_type=int)[:MAX_SRC_LEN]
    for s in train_en
]

train_sa_ids = [
    sp_sa.encode(s, out_type=int)[:MAX_TGT_LEN]
    for s in train_sa
]


In [20]:
max(len(x) for x in train_sa_ids)


128

In [21]:
from torch.utils.data import Dataset
import torch

class TranslationDataset(Dataset):
    def __init__(self, src, tgt):
        self.src = src
        self.tgt = tgt

    def __len__(self):
        return len(self.src)

    def __getitem__(self, idx):
        return (
            torch.tensor(self.src[idx], dtype=torch.long),
            torch.tensor(self.tgt[idx], dtype=torch.long)
        )


In [22]:
from torch.nn.utils.rnn import pad_sequence

PAD_ID = 0

def collate_fn(batch):
    src_batch, tgt_batch = zip(*batch)

    src_pad = pad_sequence(
        src_batch,
        batch_first=True,
        padding_value=PAD_ID
    )

    tgt_pad = pad_sequence(
        tgt_batch,
        batch_first=True,
        padding_value=PAD_ID
    )

    return src_pad, tgt_pad


In [23]:
from torch.utils.data import DataLoader

BATCH_SIZE = 32

train_ds = TranslationDataset(train_en_ids, train_sa_ids)

train_loader = DataLoader(
    train_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_fn,
    pin_memory=True
)


In [24]:
train_ds = TranslationDataset(train_en_ids, train_sa_ids)


In [25]:
print(len(train_loader))


3708


In [26]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()

        pe = torch.zeros(max_len, d_model)

        position = torch.arange(0, max_len).unsqueeze(1)

        div_term = torch.exp(
            torch.arange(0, d_model, 2) *
            (-math.log(10000.0) / d_model)
        )

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        pe = pe.unsqueeze(0)

        self.register_buffer("pe", pe)

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]


- pe = torch.zeros(max_len, d_model):
Matrix storing positional vectors. Ex: 5000 × 512

- position = torch.arange(0, max_len).unsqueeze(1):
Creates: [[0],
 [1],
 [2],
 ...
]

- div_term:
Controls frequency scaling.
Lower dimensions → slow variation,
Higher dimensions → fast variation.
This allows model to infer relative distances.

- sin and cos:
Alternating dimensions: This creates unique positional signatures.
dim0 → sin
dim1 → cos
dim2 → sin
dim3 → cos

- register_buffer:
    moves with model to GPU, not trainable, saved in checkpoints

- forward: Adds positional information to embeddings.

In [27]:
#test
pe = PositionalEncoding(d_model=512)

x = torch.zeros(2, 10, 512)  # batch=2, seq_len=10
out = pe(x)

print(out.shape)


torch.Size([2, 10, 512])


In [28]:
class ScaledDotProductAttention(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, Q, K, V, mask=None):
        d_k = Q.size(-1)

        # Attention scores
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)

        # Apply mask if provided
        if mask is not None:
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        # Attention weights
        attn = torch.softmax(scores, dim=-1)

        # Weighted sum of values
        output = torch.matmul(attn, V)

        return output, attn


In [29]:
#test
attn = ScaledDotProductAttention()

Q = torch.randn(2, 5, 64)
K = torch.randn(2, 5, 64)
V = torch.randn(2, 5, 64)

out, weights = attn(Q, K, V)

print(out.shape)
print(weights.shape)


torch.Size([2, 5, 64])
torch.Size([2, 5, 5])


A single attention head can only learn one type of relationship at a time.
But language has many relationships simultaneously. So instead of 1 attention we use multiple attention heads in parallel

In [30]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()

        assert d_model % num_heads == 0

        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        # Linear projections
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)

        self.attention = ScaledDotProductAttention()

        # Final output projection
        self.fc_out = nn.Linear(d_model, d_model)

    def split_heads(self, x):
        batch_size, seq_len, d_model = x.size()

        x = x.view(batch_size, seq_len,
                   self.num_heads, self.head_dim)

        return x.transpose(1, 2)
        # (batch, heads, seq_len, head_dim)

    def combine_heads(self, x):
        batch_size, heads, seq_len, head_dim = x.size()

        x = x.transpose(1, 2).contiguous()

        return x.view(batch_size, seq_len,
                      heads * head_dim)

    def forward(self, Q, K, V, mask=None):

        Q = self.W_q(Q)
        K = self.W_k(K)
        V = self.W_v(V)

        Q = self.split_heads(Q)
        K = self.split_heads(K)
        V = self.split_heads(V)

        out, attn = self.attention(Q, K, V, mask)

        out = self.combine_heads(out)

        out = self.fc_out(out)

        return out


**FEED FORWARD NETWORK (FFN)**

Attention mixes information between tokens.
But we also need a mechanism that:
transforms features, adds non-linearity, increases representational power

This is what the Feed Forward Network does.

In [31]:
class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()

        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)

        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        return x


- ReLU: Adds non-linearity. Without this, model becomes linear → weak.
- Dropout: Prevents overfitting. Important for MT tasks.

In [32]:
#test
ffn = FeedForward(512, 2048)

x = torch.randn(2, 10, 512)

out = ffn(x)

print(out.shape)


torch.Size([2, 10, 512])


**ENCODER LAYER (PRE-LAYERNORM):**

Input
  ->
LayerNorm
  ->
Multi-Head Self Attention
  ->
Residual Add
  ->
LayerNorm
  ->
Feed Forward Network
  ->
Residual Add
  ->
Output


In [33]:
class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()

        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.ffn = FeedForward(d_model, d_ff, dropout)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):

        # ---- Self Attention ----
        attn_out = self.self_attn(
            self.norm1(x),
            self.norm1(x),
            self.norm1(x),
            mask
        )

        x = x + self.dropout(attn_out)

        # ---- Feed Forward ----
        ffn_out = self.ffn(self.norm2(x))

        x = x + self.dropout(ffn_out)

        return x


- LayerNorm before attention(self.norm1(x)): Normalizes input distribution → stable training.
- Self-attention: Q = K = V = x
Each token attends to every other token.
- Residual connection: x = x + attn_out
Preserves original information and improves gradient flow.
- Feed Forward:
Each token independently transforms its features.
- Second residual connection: Again stabilizes deep stacking.

In [34]:
#test

enc_layer = EncoderLayer(
    d_model=512,
    num_heads=8,
    d_ff=2048
)

x = torch.randn(2, 10, 512)

out = enc_layer(x)

print(out.shape)


torch.Size([2, 10, 512])


In [35]:
class Encoder(nn.Module):
    def __init__(
        self,
        vocab_size,
        d_model,
        num_heads,
        d_ff,
        num_layers,
        dropout=0.1,
        max_len=5000
    ):
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_len)

        self.layers = nn.ModuleList([
            EncoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])

        self.dropout = nn.Dropout(dropout)
        self.d_model = d_model

    def forward(self, x, mask=None):

        # Token embedding + scaling
        x = self.embedding(x) * math.sqrt(self.d_model)

        # Add positional encoding
        x = self.pos_encoding(x)

        x = self.dropout(x)

        # Pass through encoder layers
        for layer in self.layers:
            x = layer(x, mask)

        return x


- Embedding scaling: Prevents positional encoding from dominating embeddings early in training. This comes from the original paper.
- Positional encoding:
Adds order information.
- ModuleList: 
Stores multiple encoder layers.
Each layer has independent parameters.
- Sequential processing: 
Each layer refines representation further.

In [36]:
#test

encoder = Encoder(
    vocab_size=32000,
    d_model=512,
    num_heads=8,
    d_ff=2048,
    num_layers=6
)

x = torch.randint(0, 32000, (2, 10))

out = encoder(x)

print(out.shape)


torch.Size([2, 10, 512])


The encoder only understands English.
The decoder generates Sanskrit.

- Encoder layer has:
Self-Attention → FFN
- Decoder layer has three blocks: ️Masked Self-Attention, Encoder–Decoder Attention,  Feed Forward Network

**Decoder Layer Structure:**
Input
 ->
LayerNorm
 ->
Masked Self Attention
 ->
Residual Add
 ->
LayerNorm
 ->
Cross Attention (Encoder output)
 ->
Residual Add
 ->
LayerNorm
 ->
Feed Forward
 ->
Residual Add
 ->
Output


In [37]:
class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()

        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)

        self.ffn = FeedForward(d_model, d_ff, dropout)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_out, tgt_mask=None, src_mask=None):

        # ---- Masked Self Attention ----
        self_attn_out = self.self_attn(
            self.norm1(x),
            self.norm1(x),
            self.norm1(x),
            tgt_mask
        )

        x = x + self.dropout(self_attn_out)

        # ---- Encoder-Decoder Attention ----
        cross_attn_out = self.cross_attn(
            self.norm2(x),
            enc_out,
            enc_out,
            src_mask
        )

        x = x + self.dropout(cross_attn_out)

        # ---- Feed Forward ----
        ffn_out = self.ffn(self.norm3(x))

        x = x + self.dropout(ffn_out)

        return x


- Masked Self Attention: Decoder attends only to previous generated tokens.
- Cross Attention:
Decoder looks at encoder output.
- Feed Forward:
Token-level transformation again.

In [38]:
#test

dec_layer = DecoderLayer(
    d_model=512,
    num_heads=8,
    d_ff=2048
)

x = torch.randn(2, 10, 512)
enc_out = torch.randn(2, 10, 512)

out = dec_layer(x, enc_out)

print(out.shape)


torch.Size([2, 10, 512])


In [39]:
class Decoder(nn.Module):
    def __init__(
        self,
        vocab_size,
        d_model,
        num_heads,
        d_ff,
        num_layers,
        dropout=0.1,
        max_len=5000
    ):
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_len)

        self.layers = nn.ModuleList([
            DecoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])

        self.dropout = nn.Dropout(dropout)
        self.d_model = d_model

    def forward(self, x, enc_out, tgt_mask=None, src_mask=None):

        # Embedding + scaling
        x = self.embedding(x) * math.sqrt(self.d_model)

        # Positional encoding
        x = self.pos_encoding(x)

        x = self.dropout(x)

        # Pass through decoder layers
        for layer in self.layers:
            x = layer(x, enc_out, tgt_mask, src_mask)

        return x


- Embedding scaling: Same reason as encoder, stabilizes early training.
- Decoder layers: Refines generation, aligns with encoder output better.
- Mask handling: causal mask, padding mask

In [40]:
#test

decoder = Decoder(
    vocab_size=32000,
    d_model=512,
    num_heads=8,
    d_ff=2048,
    num_layers=6
)

tgt = torch.randint(0, 32000, (2, 10))
enc_out = torch.randn(2, 10, 512)

out = decoder(tgt, enc_out)

print(out.shape)


torch.Size([2, 10, 512])


In [41]:
class Transformer(nn.Module):
    def __init__(
        self,
        src_vocab_size,
        tgt_vocab_size,
        d_model=512,
        num_heads=8,
        num_layers=6,
        d_ff=2048,
        dropout=0.1
    ):
        super().__init__()

        self.encoder = Encoder(
            src_vocab_size,
            d_model,
            num_heads,
            d_ff,
            num_layers,
            dropout
        )

        self.decoder = Decoder(
            tgt_vocab_size,
            d_model,
            num_heads,
            d_ff,
            num_layers,
            dropout
        )

        self.fc_out = nn.Linear(d_model, tgt_vocab_size)

    def forward(
        self,
        src,
        tgt,
        src_mask=None,
        tgt_mask=None
    ):

        enc_out = self.encoder(src, src_mask)

        dec_out = self.decoder(
            tgt,
            enc_out,
            tgt_mask,
            src_mask
        )

        out = self.fc_out(dec_out)

        return out


- Encoder: Processes English sentence into contextual vectors.
- Decoder: Uses previous Sanskrit tokens, encoder output to generate next token representation.

In [42]:
#test

model = Transformer(
    src_vocab_size=16000,
    tgt_vocab_size=32000
)

src = torch.randint(0, 16000, (2, 10))
tgt = torch.randint(0, 32000, (2, 10))

out = model(src, tgt)

print(out.shape)


torch.Size([2, 10, 32000])


**MASKING**

In [43]:
def create_padding_mask(seq, pad_id=0):
    return (seq != pad_id).unsqueeze(1).unsqueeze(2)


In [44]:
def create_causal_mask(size):
    mask = torch.tril(torch.ones(size, size))
    return mask.bool()


**LOSS FUNCTION (LABEL SMOOTHING)**

In [45]:
criterion = nn.CrossEntropyLoss(
    ignore_index=0,
    label_smoothing=0.1
)


**OPTIMIZER**

In [46]:
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=1e-4,
    betas=(0.9, 0.98),
    eps=1e-9
)


In [47]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = model.to(device)


In [48]:
scaler = torch.cuda.amp.GradScaler()

  scaler = torch.cuda.amp.GradScaler()


In [49]:
torch.cuda.empty_cache()


**TRAINING**

In [50]:
def train_epoch(model, loader, optimizer, criterion, scaler, device):

    model.train()
    total_loss = 0

    for step, (src, tgt) in enumerate(loader):


        src = src.to(device, non_blocking=True)
        tgt = tgt.to(device, non_blocking=True)

        # Teacher forcing split
        tgt_input = tgt[:, :-1]
        tgt_output = tgt[:, 1:]

        # Masks
        src_mask = create_padding_mask(src).to(device)

        tgt_mask = create_padding_mask(tgt_input).to(device)
        causal_mask = create_causal_mask(
            tgt_input.size(1)
        ).to(device)

        tgt_mask = tgt_mask & causal_mask.unsqueeze(0).unsqueeze(0)


        optimizer.zero_grad()

        with torch.cuda.amp.autocast():

            output = model(
                src,
                tgt_input,
                src_mask,
                tgt_mask
            )

            output = output.reshape(-1, output.size(-1))
            tgt_output = tgt_output.reshape(-1)

            loss = criterion(output, tgt_output)

        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item()

    return total_loss / len(loader)


In [51]:
EPOCHS = 10

for epoch in range(EPOCHS):

    loss = train_epoch(
        model,
        train_loader,
        optimizer,
        criterion,
        scaler,
        device
    )

    print(f"Epoch {epoch+1}: Loss = {loss:.4f}")


  with torch.cuda.amp.autocast():


Epoch 1: Loss = 22.0577
Epoch 2: Loss = 7.5415
Epoch 3: Loss = 7.0613
Epoch 4: Loss = 6.8028
Epoch 5: Loss = 6.6203
Epoch 6: Loss = 6.4790
Epoch 7: Loss = 6.3626
Epoch 8: Loss = 6.2658
Epoch 9: Loss = 6.1834
Epoch 10: Loss = 6.1120
