<a href="https://colab.research.google.com/github/AndreiS22/deep_learning_labs/blob/main/Local_GCD.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install torch

In [None]:
import torch
import torch.nn as nn
import math
import numpy as np
from torch.utils.data import Dataset, DataLoader

class MathTokenizer:
    """Character-level tokenizer for signed-integer strings such as ``"+12+30"``.

    The vocabulary is the four special tokens followed by ``'+'``, ``'-'`` and
    the digits ``'0'``-``'9'``; a token's id is its index in ``self.vocab``.
    """

    def __init__(self):
        self.special_tokens = ['[PAD]', '[SOS]', '[EOS]', '[UNK]']
        self.pad_token = '[PAD]'
        self.sos_token = '[SOS]'
        self.eos_token = '[EOS]'
        self.unk_token = '[UNK]'
        self.vocab = self.special_tokens + ['+', '-'] + [str(d) for d in range(10)]
        # Forward and reverse lookup tables between tokens and integer ids.
        self.token2id = {t: i for i, t in enumerate(self.vocab)}
        self.id2token = {i: t for i, t in enumerate(self.vocab)}

    def tokenize(self, text):
        """Split ``text`` into a list of single-character tokens."""
        return list(text)

    def encode(self, text):
        """Return the ids of ``[SOS] + characters of text + [EOS]``.

        Characters outside the vocabulary are mapped to the ``[UNK]`` id.
        """
        unk_id = self.token2id[self.unk_token]
        tokens = [self.sos_token] + self.tokenize(text) + [self.eos_token]
        return [self.token2id.get(t, unk_id) for t in tokens]

    def decode(self, token_ids):
        """Join the tokens for ``token_ids`` into a string.

        ``[PAD]``, ``[SOS]`` and ``[EOS]`` ids are dropped; ids that are not in
        the vocabulary are rendered as the literal ``[UNK]`` token.
        """
        # Build the set of ids to drop once, not once per decoded token.
        skipped_ids = {
            self.token2id[t]
            for t in (self.pad_token, self.sos_token, self.eos_token)
        }
        return ''.join(
            self.id2token.get(i, self.unk_token)
            for i in token_ids
            if i not in skipped_ids
        )

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch of embeddings.

    Args:
        d_model: Size of the embedding dimension (even or odd).
        max_len: Number of positions precomputed in the encoding table.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even columns, so
        # trim the cosine block; for even d_model the slice is a no-op.
        pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])
        # Stored as a buffer so it follows the module across devices without
        # being a trainable parameter.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``x.size(1)`` positions.

        ``x`` is indexed as (batch, seq_len, d_model): the table is sliced by
        ``x.size(1)`` and broadcast over the leading dimension. Sequences
        longer than ``max_len`` are not supported (the slice would be shorter
        than the sequence and the addition cannot broadcast).
        """
        return x + self.pe[:x.size(1)]

class LCMTransformer(nn.Module):
    """Encoder-decoder Transformer over tokenizer ids.

    Source and target sequences share one embedding table and one positional
    encoding; decoder outputs are projected to vocabulary logits. Note that
    ``predict`` is documented as predicting the gcd of its operands even
    though the class name says "LCM".

    Args:
        tokenizer: Object exposing ``vocab``, ``token2id``, ``pad_token``,
            ``sos_token``, ``eos_token``, ``encode`` and ``decode``.
        d_model: Embedding / model dimension.
        nhead: Number of attention heads.
        num_layers: Number of encoder layers and of decoder layers.
        max_length: Number of positions in the positional-encoding table.
    """

    def __init__(self, tokenizer, d_model=128, nhead=8, num_layers=3, max_length=512):
        super().__init__()
        self.tokenizer = tokenizer
        self.vocab_size = len(tokenizer.vocab)

        self.embedding = nn.Embedding(self.vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model, max_length)

        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_layers,
            num_decoder_layers=num_layers,
            batch_first=True
        )

        self.fc_out = nn.Linear(d_model, self.vocab_size)
        self.pad_id = tokenizer.token2id[tokenizer.pad_token]

    def forward(self, src, tgt):
        """Return vocabulary logits for every target position.

        Args:
            src: Token ids, shape (batch_size, src_len).
            tgt: Decoder input token ids, shape (batch_size, tgt_len).

        Returns:
            Logits of shape (batch_size, tgt_len, vocab_size).
        """
        # Embedding + positional encoding
        src_emb = self.pos_encoder(self.embedding(src))
        tgt_emb = self.pos_encoder(self.embedding(tgt))

        # Causal mask: a target position may only attend to earlier positions.
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt.size(1)).to(src.device)

        # Key padding masks must be 2D (batch_size, seq_len); True marks padding.
        src_key_padding_mask = (src == self.pad_id)
        tgt_key_padding_mask = (tgt == self.pad_id)

        output = self.transformer(
            src=src_emb,
            tgt=tgt_emb,
            tgt_mask=tgt_mask,
            src_key_padding_mask=src_key_padding_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=src_key_padding_mask
        )

        return self.fc_out(output)

    @staticmethod
    def _format_operand(value):
        """Render an integer as an explicit sign followed by its magnitude."""
        sign = '-' if value < 0 else '+'
        return f"{sign}{abs(value)}"

    def predict(self, a, b, max_length=10):
        """Predict gcd for integers a and b by greedy decoding.

        Decoding always runs in eval mode (so dropout cannot randomise the
        answer) and without gradients; the module's previous train/eval mode
        is restored before returning.

        Args:
            a: First integer operand.
            b: Second integer operand.
            max_length: Maximum number of tokens to generate.

        Returns:
            The decoded integer, or ``None`` if the generated text is not a
            valid integer.
        """
        # Prepare input (format: "+a+b"; a negative operand becomes "-|a|").
        input_str = self._format_operand(a) + self._format_operand(b)
        input_ids = self.tokenizer.encode(input_str)
        src = torch.tensor(input_ids).unsqueeze(0).to(next(self.parameters()).device)  # Add batch dim

        sos_id = self.tokenizer.token2id[self.tokenizer.sos_token]
        eos_id = self.tokenizer.token2id[self.tokenizer.eos_token]
        tgt_ids = [sos_id]

        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                for _ in range(max_length):
                    tgt = torch.tensor(tgt_ids).unsqueeze(0).to(src.device)
                    output = self(src, tgt)

                    # Greedy choice: most likely token at the last position.
                    next_token = output.argmax(-1)[:, -1].item()
                    tgt_ids.append(next_token)

                    # Stop if EOS token is generated
                    if next_token == eos_id:
                        break
        finally:
            self.train(was_training)

        # Convert tokens to string and parse result
        pred_str = self.tokenizer.decode(tgt_ids[1:])  # Skip SOS token
        try:
            return int(pred_str.lstrip('+'))
        except ValueError:
            return None  # In case of invalid prediction

class LCMDataset(Dataset):
    """Randomly generated (operand pair -> gcd) token-id sequences.

    Each sample's source text is the two operands written with explicit signs
    (e.g. ``"+12+18"``) and its target text is ``"+<gcd>"``; both are encoded
    with a ``MathTokenizer``. Despite the class name, the target is the
    greatest common divisor.

    Args:
        max_num: Exclusive upper bound for the operands, which are drawn from
            ``[1, max_num)``; it must therefore be at least 2.
        num_samples: Number of samples to generate.
        seed: Seed for the NumPy ``RandomState``; ``None`` gives fresh
            randomness on every construction.
    """

    def __init__(self, max_num=100, num_samples=100000, seed=42):
        self.rng = np.random.RandomState(seed)

        self.tokenizer = MathTokenizer()
        self.data = []

        for _ in range(num_samples):
            # Convert NumPy scalars to plain ints before math.gcd / formatting.
            a, b = (int(v) for v in self.rng.randint(1, max_num, size=2))
            gcd = math.gcd(a, b)

            src = self._format_operand(a) + self._format_operand(b)
            tgt = f"+{gcd}"

            self.data.append((
                self.tokenizer.encode(src),
                self.tokenizer.encode(tgt)
            ))

    @staticmethod
    def _format_operand(value):
        """Render an integer as an explicit sign followed by its magnitude."""
        # Formatting abs(value) avoids a doubled minus sign ("--3") should
        # negative operands ever be generated.
        sign = '-' if value < 0 else '+'
        return f"{sign}{abs(value)}"

    def __len__(self):
        """Return the number of generated samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(src_ids, tgt_ids)`` tensors of sample ``idx``."""
        src_ids, tgt_ids = self.data[idx]
        return (
            torch.tensor(src_ids),
            torch.tensor(tgt_ids)
        )

def collate_fn(batch, pad_id=None):
    """Pad a list of ``(src, tgt)`` tensor pairs into two batch-first tensors.

    Args:
        batch: Sequence of ``(src_ids, tgt_ids)`` 1-D tensor pairs.
        pad_id: Id used for padding. When omitted, the pad id of the
            module-level ``tokenizer`` is used, so that global must exist by
            the time the first batch is collated.

    Returns:
        ``(src, tgt)``, each padded to the longest sequence in the batch.
    """
    if pad_id is None:
        pad_id = tokenizer.token2id[tokenizer.pad_token]
    src_batch, tgt_batch = zip(*batch)
    src = nn.utils.rnn.pad_sequence(src_batch, padding_value=pad_id, batch_first=True)
    tgt = nn.utils.rnn.pad_sequence(tgt_batch, padding_value=pad_id, batch_first=True)
    return src, tgt

from tqdm import tqdm  # For progress bar

def compute_accuracy(model, dataloader, device):
    """Return the fraction of sequences the model predicts exactly.

    The model is switched to eval mode and run with teacher forcing: it gets
    every target token except the last and must reproduce every target token
    except the first. A sequence counts as correct when all of its non-padding
    positions match. Returns 0.0 when ``dataloader`` yields no batches.
    """
    model.eval()
    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        for src, tgt in dataloader:
            src = src.to(device)
            tgt = tgt.to(device)
            decoder_in, expected = tgt[:, :-1], tgt[:, 1:]

            predicted = model(src, decoder_in).argmax(-1)

            # Padding positions are treated as matching so they cannot fail a row.
            is_padding = expected == model.pad_id
            row_is_correct = ((predicted == expected) | is_padding).all(dim=1)

            n_correct += row_is_correct.sum().item()
            n_seen += tgt.size(0)

    # Guard against division by zero for an empty dataloader.
    return n_correct / max(n_seen, 1)


In [None]:
# Initialize components
tokenizer = MathTokenizer()
validate_step = 2  # run validation every `validate_step` epochs
layers = 4  # number of encoder layers and of decoder layers
heads = 8
hidden_dimension = 512  # passed to the model as d_model
length = 512  # size of the positional-encoding table
lr = 10e-5  # NOTE: 10e-5 equals 1e-4, not 1e-5
batch = 256
max_int = 1000000  # exclusive upper bound for the sampled operands
sample_size = 10000  # samples per generated dataset
max_epoch = 1000
seed = None  # None: every generated dataset uses fresh randomness


def _split_seed(offset):
    """Return a distinct dataset seed derived from `seed`, or None if unseeded.

    Without this, an integer `seed` would make every epoch's training set, the
    validation set and the test set one and the same data.
    """
    if seed is None:
        return None
    return (seed + offset) % (2 ** 32)  # RandomState seeds must fit in 32 bits


model = LCMTransformer(tokenizer, d_model=hidden_dimension, nhead=heads, num_layers=layers, max_length=length)

# Prefer CUDA (e.g. on Colab), then Apple MPS, then fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
model.to(device)


# Training loop
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# Padding positions contribute nothing to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.token2id[tokenizer.pad_token])

counter = 0

for epoch in range(max_epoch):
    model.train()
    total_loss = 0
    # A new training set is generated every epoch.
    train_dataset = LCMDataset(max_num=max_int, num_samples=sample_size, seed=_split_seed(2 + epoch))
    train_dataloader = DataLoader(train_dataset, batch_size=batch, collate_fn=collate_fn)
    total_sequences = 0
    perfect_sequences = 0

    for src, tgt in train_dataloader:
        src, tgt = src.to(device), tgt.to(device)

        # Teacher forcing: feed all but the last token, predict all but the first.
        tgt_input = tgt[:, :-1]
        tgt_output = tgt[:, 1:]

        optimizer.zero_grad()
        output = model(src, tgt_input)

        loss = criterion(
            output.reshape(-1, output.size(-1)),
            tgt_output.reshape(-1)
        )

        # Running exact-match accuracy (measured in train mode, dropout active).
        with torch.no_grad():
            preds = output.argmax(-1)
            mask = (tgt_output != model.pad_id)
            seq_match = (preds == tgt_output) | ~mask
            perfect_sequences += seq_match.all(dim=1).sum().item()
            total_sequences += tgt.size(0)

        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    counter += 1

    print(f"Epoch {epoch+1}, Loss: {total_loss/max(len(train_dataloader), 1)}, Accuracy: {perfect_sequences / max(total_sequences, 1)}")
    if counter % validate_step == 0:
        # Accuracy validation
        validation_dataset = LCMDataset(max_num=max_int, num_samples=sample_size, seed=_split_seed(0))
        validation_dataloader = DataLoader(validation_dataset, batch_size=batch, collate_fn=collate_fn)
        accuracy = compute_accuracy(model, validation_dataloader, device)
        print(f"Validation Accuracy: {accuracy}")

test_dataset = LCMDataset(max_num=max_int, num_samples=sample_size, seed=_split_seed(1))
test_dataloader = DataLoader(test_dataset, batch_size=batch, collate_fn=collate_fn)
print(f"Test Accuracy: {compute_accuracy(model, test_dataloader, device)}")

In [None]:
# Spot check on a single pair: switch to inference mode, then decode.
model.eval()
model.predict(a=21, b=7)

In [None]:
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# y[k-1] = 6 / (pi * k)^2 for k = 1..9 (presumably the asymptotic probability
# that two random integers have gcd k; confirm against the intended model).
x = np.arange(1, 10)
y = 6 / (np.pi * x)**2
# Running total of y: one cumulative pass instead of re-summing every prefix.
z = np.cumsum(y)
fig, ax = plt.subplots(1, 2, figsize=(12,5))

ax[0].plot(x, y)
ax[1].plot(x, z)

ax[0].set_title("PMF of The Distribution of GCD")
ax[0].set_xlabel("GCD")
ax[0].set_ylabel("Probability")
ax[1].set_title("CMF of The Distribution of GCD")
ax[1].set_xlabel("GCD")
ax[1].set_ylabel("Probability")