<a href="https://colab.research.google.com/github/JackBrandt/APIFunF24/blob/master/JacksTransformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## HW 11: Build a Transformer

In [None]:
import torch.nn as nn
import torch
import math
import random
from numpy.random import randint
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

# Vocabulary: lowercase letters, then uppercase letters, then digits.
alphabet = [*'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789']
alphabet_length = len(alphabet)
# Reverse lookup: character -> its position in `alphabet`.
char_to_idx = dict(zip(alphabet, range(alphabet_length)))

def token_vectoriztion(token):
  """Return the one-hot encoding of a single token over `alphabet`.

  A token that is not in `alphabet` is encoded as the all-zero vector.
  The result has `len(alphabet)` entries and lives on `device`.
  """
  encoding=torch.zeros(len(alphabet),device=device)
  if token in alphabet:
    # Known token: switch on the slot matching its alphabet position.
    encoding[alphabet.index(token)]=1.0
  return encoding

def token_list_vectorization(tokens):
    """One-hot encode a sequence of tokens into a [len(tokens), alphabet_length] tensor.

    Tokens missing from `char_to_idx` produce an all-zero row: they are
    encoded as index 0 first and then blanked out by a 0/1 mask.
    """
    # None marks a token that is not part of the alphabet.
    lookups = [char_to_idx.get(token) for token in tokens]
    idx_list = [0 if found is None else found for found in lookups]
    mask = [0 if found is None else 1 for found in lookups]

    indices = torch.tensor(idx_list, device=device)
    one_hot = F.one_hot(indices, num_classes=alphabet_length).float()
    # Column vector so the mask broadcasts across each one-hot row.
    keep = torch.tensor(mask, device=device).unsqueeze(1).float()
    return one_hot * keep

def batch_token_vectorization(batch):
    """Stack the one-hot encoding of every token list in `batch` along a new leading axis.

    All token lists must have the same length, since `torch.stack` requires
    equally shaped tensors.
    """
    encoded = []
    for tokens in batch:
        encoded.append(token_list_vectorization(tokens))
    return torch.stack(encoded)

def vector_to_token(vector):
  """Return the alphabet character at the position of the largest entry of `vector`."""
  best_index=vector.argmax()
  return alphabet[best_index]

# Denominators 1000**(k/2) for k = 0..30, consumed by single_position().
# NOTE(review): 31 looks like alphabet_length/2 hard-coded, and the largest
# exponents probably overflow float32 to inf -- confirm this is intended.
denom_vector=torch.pow(1000,torch.arange(0,31)/2)

def single_position(index):
  """Build the positional vector for one sequence position.

  `index` is divided elementwise by `denom_vector`; the sines of the
  result are followed by its cosines, giving a 1-D tensor on `device`.
  """
  half_width=int(alphabet_length/2)
  # Same position value repeated once per frequency, then scaled.
  angles=torch.div(
    torch.full((half_width,),float(index),device=device),
    denom_vector.to(device),
  )
  return torch.cat((torch.sin(angles),torch.cos(angles)),0)

def positional_embedding(token_list):
  """Stack one positional vector per position of `token_list`.

  Only the length of `token_list` is used; row i of the result is
  `single_position(i)`.
  """
  positions=torch.arange(len(token_list),device=device)
  rows=[single_position(position) for position in positions]
  return torch.stack(rows)

def positional_embedding_batch(length):
    """Compute positional vectors for positions 0..length-1 in one vectorized pass.

    Returns a [length, 2 * (alphabet_length // 2)] tensor on `device`:
    the sine half first, then the cosine half.
    """
    half_dim = alphabet_length // 2
    # Column of positions, shape [length, 1], so it broadcasts against the denominators.
    positions = torch.arange(0, length, device=device).unsqueeze(1).float()
    exponents = torch.arange(0, half_dim, device=device).float() / 2
    denom = torch.pow(1000, exponents)
    angles = positions / denom  # shape [length, half_dim]
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)

def softmax(tensor):
  """Return the softmax of `tensor` taken along dimension 0.

  Uses `torch.softmax`, which is numerically stable: the naive
  exp(x) / sum(exp(x)) formula overflows to inf/inf = NaN once any entry
  is large (roughly > 88 in float32).

  Args:
    tensor: real-valued tensor. Integer and boolean tensors are promoted
      to the default floating dtype first, matching what `torch.exp`
      would have produced.

  Returns:
    A tensor of the same shape whose entries along dimension 0 sum to 1.
  """
  if not (tensor.is_floating_point() or tensor.is_complex()):
    tensor=tensor.to(torch.get_default_dtype())
  return torch.softmax(tensor,dim=0)


class jacks_transformer(nn.Module):
  """String classifier: one-hot + positional embedding -> encoder -> mean pool -> MLP head.

  `forward` takes a batch of raw strings and returns one raw score (logit)
  per string, shape [batch, 1]; no sigmoid is applied here.
  """
  def __init__(self, context_size=4):
    """Build the encoder and output head.

    Args:
      context_size: number of characters each input string is padded or
        truncated to before embedding.
    """
    super(jacks_transformer, self).__init__()
    self.context_size=context_size
    self.alphabet=list('abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789')
    self.alphabet_length=len(self.alphabet)
    self.encoder = encoder(num_layers=1, num_heads=1, alphabet_length=self.alphabet_length)
    # Size the head from this instance's alphabet, like the encoder above,
    # instead of silently depending on the module-level `alphabet_length`.
    self.output_NN=output_NN(self.alphabet_length)

  def get_batch(self,input_strings):
    """Turn each string into a list of exactly `context_size` characters.

    Short strings are padded with ' ' alternately on the right, then on the
    left, until they are long enough; long strings keep only their first
    `context_size` characters.
    """
    batch=[list(string) for string in input_strings]
    for i,string in enumerate(batch):
      while len(string)<self.context_size:
        string.append(' ')
        if len(string)<self.context_size:
          string.insert(0,' ')
      batch[i]=string[:self.context_size]
    return batch

  def token_vectoriztion(self,batch):
    """One-hot encode a batch of character lists (see batch_token_vectorization)."""
    return batch_token_vectorization(batch)

  def positional_embedding(self,batch):
    """Return the positional table for `context_size` positions, repeated once per batch item."""
    pos_embed = positional_embedding_batch(self.context_size)
    # expand() adds the batch axis as a broadcast view; no data is copied.
    return pos_embed.unsqueeze(0).expand(len(batch), -1, -1)

  def combined_embedding(self,input_strings):
    """Pad/truncate the strings, then add token and positional encodings elementwise."""
    batch=self.get_batch(input_strings)
    return self.token_vectoriztion(batch)+self.positional_embedding(batch)

  def forward(self,input):
    """Score a batch of raw strings.

    Args:
      input: iterable of strings.

    Returns:
      Tensor of shape [batch, 1] holding one logit per string.
    """
    embedding=self.combined_embedding(input)
    encoding=self.encoder(embedding)
    # Average over the sequence axis to get one vector per string.
    aggregated = torch.mean(encoding, dim=1)
    output=self.output_NN(aggregated)
    return output

class EncoderBlock(nn.Module):
    """One encoder layer: multi-head attention and a feed-forward net, each with residual + LayerNorm."""

    def __init__(self, alphabet_length, num_heads):
        """Create `num_heads` attention heads, the output projection, the FFNN and both norms."""
        super().__init__()
        heads = [attention_head(alphabet_length) for _ in range(num_heads)]
        self.attention_heads = nn.ModuleList(heads)
        # Maps the concatenated head outputs back down to alphabet_length features.
        projection_shape = (num_heads * alphabet_length, alphabet_length)
        self.proj_weights = nn.Parameter(torch.rand(projection_shape))
        self.layer_norm1 = nn.LayerNorm(alphabet_length)
        self.ffnn = FFNN(alphabet_length)
        self.layer_norm2 = nn.LayerNorm(alphabet_length)

    def forward(self, x):
        """Apply attention then the FFNN; heads are concatenated on dim 2, so `x` must be 3-D."""
        per_head = [head(x) for head in self.attention_heads]
        merged = torch.cat(per_head, dim=2)
        # Residual connection around the attention sub-layer.
        attended = self.layer_norm1(x + torch.matmul(merged, self.proj_weights))
        # Residual connection around the feed-forward sub-layer.
        return self.layer_norm2(attended + self.ffnn(attended))

class encoder(nn.Module):
    """A stack of `num_layers` EncoderBlock modules applied one after another."""

    def __init__(self, num_layers, num_heads, alphabet_length):
        """Build `num_layers` blocks, each with `num_heads` heads over `alphabet_length` features."""
        super().__init__()
        blocks = [EncoderBlock(alphabet_length, num_heads) for _ in range(num_layers)]
        self.layers = nn.ModuleList(blocks)

    def forward(self, x):
        """Feed `x` through every block in order and return the final output."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden)
        return hidden


class attention_head(nn.Module):
  """Single scaled dot-product self-attention head with square Q/K/V projections."""
  def __init__(self,alphabet_length):
    """Create the query, key and value matrices, each alphabet_length x alphabet_length.

    Weights are drawn from uniform [0, 1) via `torch.rand`.
    """
    super(attention_head, self).__init__()
    self.query_weights=nn.Parameter(torch.rand((alphabet_length,alphabet_length)))
    self.key_weights=nn.Parameter(torch.rand((alphabet_length,alphabet_length)))
    self.value_weights=nn.Parameter(torch.rand((alphabet_length,alphabet_length)))

  def forward(self,batch,hide=True):
    """Apply self-attention over the second-to-last axis of `batch`.

    Args:
      batch: tensor of shape [..., seq_len, alphabet_length]; any number of
        leading batch dimensions (including none) is accepted.
      hide: accepted for interface compatibility but currently unused; no
        attention mask is applied.

    Returns:
      Tensor with the same shape as `batch`.
    """
    query=torch.matmul(batch,self.query_weights)
    key=torch.matmul(batch,self.key_weights)
    value=torch.matmul(batch,self.value_weights)
    # Dividing by sqrt(feature size) keeps the dot products in a range
    # where the softmax does not saturate.
    scores=torch.matmul(query,key.transpose(-2,-1))/math.sqrt(query.size(-1))
    attention_weights=torch.softmax(scores,dim=-1)
    # matmul rather than bmm: same result for 3-D input, but also works for
    # unbatched 2-D input and for extra leading batch dimensions.
    return torch.matmul(attention_weights,value)


class FFNN(nn.Module):
    """Position-wise feed-forward network mapping alphabet_length features back to alphabet_length."""

    def __init__(self, alphabet_length, hidden_dim=None, num_layers=2):
        """Assemble Linear/ReLU layers.

        Args:
            alphabet_length: input and output feature size.
            hidden_dim: width of the hidden layers; defaults to alphabet_length.
            num_layers: total number of Linear layers; values below 2 still
                produce the input and output Linear layers.
        """
        super().__init__()
        width = alphabet_length if hidden_dim is None else hidden_dim
        # Layers are created in forward order: input, extra hidden, output.
        stack = [nn.Linear(alphabet_length, width), nn.ReLU()]
        for _ in range(num_layers - 2):
            stack.extend([nn.Linear(width, width), nn.ReLU()])
        stack.append(nn.Linear(width, alphabet_length))
        self.network = nn.Sequential(*stack)

    def forward(self, x):
        """Run `x` through the layer stack."""
        return self.network(x)


class output_NN(nn.Module):
    """Two-layer MLP head that reduces a feature vector to a single score."""

    def __init__(self, input_dim, hidden_dim=32):
        """Build Linear(input_dim, hidden_dim) -> ReLU -> Linear(hidden_dim, 1)."""
        super().__init__()
        to_hidden = nn.Linear(input_dim, hidden_dim)
        to_score = nn.Linear(hidden_dim, 1)
        self.network = nn.Sequential(to_hidden, nn.ReLU(), to_score)

    def forward(self, x):
        """Return one score per input row (last dimension has size 1)."""
        return self.network(x)

class alphabet_generator():
  """Random generator of positive ("L") and near-miss negative example strings.

  With the run length fixed at 1, an L string has the form x y x with
  x != y; an almost-L string has exactly one of its three runs lengthened
  or shortened by one character.
  """
  def __init__(self):
    self.alphabet=list('abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789')
    self.alphabet_length=len(self.alphabet)

  def make_L_string(self):
    """Return a random string of the form x^n y^n x^n with x != y (n is fixed at 1)."""
    n=1
    size=len(self.alphabet)
    i=randint(0,size)
    j=randint(0,size)
    # Redraw the middle symbol until it differs from the outer one.
    while j==i:
      j=randint(0,size)
    outer=self.alphabet[i]*n
    return outer+self.alphabet[j]*n+outer

  def make_almostL_string(self):
    """Return an L-like string with one run that is one character too long or too short."""
    n=1
    size=len(self.alphabet)
    i=randint(0,size)
    j=randint(0,size)
    while j==i:
      j=randint(0,size)
    x=randint(0,6)
    # Run lengths (first, middle, last) for each of the six error kinds:
    # entries 0-2 lengthen one run, entries 3-5 shorten one run.
    run_lengths=(
      (n+1,n,n),
      (n,n+1,n),
      (n,n,n+1),
      (n-1,n,n),
      (n,n-1,n),
      (n,n,n-1),
    )
    first,middle,last=run_lengths[x]
    return self.alphabet[i]*first+self.alphabet[j]*middle+self.alphabet[i]*last

  def make_random_language_data(self,volume):
    """Return (strings, labels) with `volume` examples, each positive or negative by coin flip.

    Labels are one-element lists: [1.0] for an L string, [0.0] otherwise.
    """
    train_strings=[]
    train_labels=[]
    for _ in range(volume):
      is_member=random.randint(0,1)==1
      if is_member:
        example,label=self.make_L_string(),[1.0]
      else:
        example,label=self.make_almostL_string(),[0.0]
      train_strings.append(example)
      train_labels.append(label)
    return train_strings,train_labels

  def make_language_data(self,volume):
    """Return `volume` L strings paired with integer labels 1."""
    members=[self.make_L_string() for _ in range(volume)]
    return members, [1]*volume

  def make_not_language_data(self,volume):
    """Return `volume` almost-L strings paired with integer labels 0."""
    non_members=[self.make_almostL_string() for _ in range(volume)]
    return non_members, [0]*volume


def train(model, optimizer, loss_fn, train_loader, num_epochs=100):
    """Train `model` on (strings, labels) batches and print the mean loss every 10 epochs.

    Args:
        model: module called as `model(batch_strings)`; expected to return
            predictions of shape [batch_size, 1].
        optimizer: optimizer over the model's parameters.
        loss_fn: callable `loss_fn(predictions, targets)` returning a scalar tensor.
        train_loader: re-iterable yielding `(batch_strings, batch_labels)` pairs;
            labels may be a tensor (as produced by DataLoader) or a list of floats.
        num_epochs: number of passes over `train_loader`.
    """
    model.train()
    for epoch in range(num_epochs):
        total_loss = 0.0
        num_batches = 0
        for batch_strings, batch_labels in train_loader:
            optimizer.zero_grad()
            # The model receives raw strings; its forward() builds the embedding.
            predictions = model(batch_strings)  # shape: [batch_size, 1]
            # as_tensor converts lists and existing tensors alike without the
            # copy-construct UserWarning that torch.tensor(tensor) emits, and
            # the targets follow whatever device the predictions are on.
            targets = torch.as_tensor(
                batch_labels, dtype=torch.float32, device=predictions.device
            ).unsqueeze(1)
            loss = loss_fn(predictions, targets)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            num_batches += 1

        # An empty loader has no meaningful average; report NaN instead of
        # raising ZeroDivisionError.
        avg_loss = total_loss / num_batches if num_batches else float("nan")
        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}")

def evaluate(model, test_strings, test_labels, threshold=0.5):
    """Print a per-example report and the overall accuracy of `model`.

    Args:
        model: module called as `model(test_strings)` that returns one logit
            per string, shape [batch_size, 1].
        test_strings: list of raw strings.
        test_labels: one label per string, either a scalar or a one-element
            list (e.g. 1.0 or [1.0]).
        threshold: a sigmoid probability at or above this counts as class 1.

    Returns:
        The fraction of correctly classified strings, or NaN when
        `test_strings` is empty.
    """
    model.eval()
    if len(test_strings) == 0:
        # Nothing to score; avoid calling the model and dividing by zero.
        print("Accuracy: n/a (no test strings)")
        return float("nan")
    with torch.no_grad():
        logits = model(test_strings)
        predictions = torch.sigmoid(logits)  # shape: [batch_size, 1]
    # Convert probabilities to hard 0/1 decisions.
    pred_labels = (predictions >= threshold).float()
    test_labels_tensor = torch.as_tensor(test_labels, dtype=torch.float32)
    tally = 0
    rows = zip(test_strings, test_labels_tensor, pred_labels, predictions)
    for i, (string, true_label, pred, prob) in enumerate(rows, start=1):
        true_int = int(true_label.item())
        pred_int = int(pred.item())
        print(f"{i}: String: {string}, True: {true_int}, Predicted: {pred_int}, Probability: {prob.item():.4f}")
        if true_int == pred_int:
            tally += 1
    accuracy = tally / len(test_strings)
    print(f"Accuracy: {accuracy}")
    return accuracy

class LanguageDataset(Dataset):
    """Dataset of (string, label) pairs for use with a DataLoader."""

    def __init__(self, strings, labels):
        """Store the examples and flatten every label to a plain float.

        strings: list of raw string examples
        labels: one label per string, either a number or a list whose first
            element is the number (e.g. [1.0] or [0.0])
        """
        self.strings = strings
        flattened = []
        for label in labels:
            # List-wrapped labels contribute their first element.
            value = label[0] if isinstance(label, list) else label
            flattened.append(float(value))
        self.labels = flattened

    def __len__(self):
        """Number of examples, taken from the strings list."""
        return len(self.strings)

    def __getitem__(self, idx):
        """Return the (string, label) tuple stored at position `idx`."""
        text = self.strings[idx]
        target = self.labels[idx]
        return text, target

# --- Driver: build the model, train it on random data, then evaluate. ---
AG = alphabet_generator()
# Default context_size (4); parameters are moved to the selected device.
JT = jacks_transformer().to(device)
optimizer = optim.Adam(JT.parameters(), lr=0.002)
# The model outputs raw logits, so use the loss that takes logits directly.
loss_fn = nn.BCEWithLogitsLoss()
# 800 random examples; make_random_language_data returns (strings, labels).
training_data=AG.make_random_language_data(800)
train_strings=training_data[0]
train_labels=training_data[1]

# Wrap the generated strings/labels so DataLoader can batch them.
train_dataset = LanguageDataset(train_strings, train_labels)
# Change the batch_size as needed; here we use 32.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

print("Training...")
train(JT, optimizer, loss_fn, train_loader, num_epochs=50)

# Fresh (strings, labels) sample of 32 examples for evaluation.
example_strings=AG.make_random_language_data(32)
print()
# NOTE(review): this header is printed but nothing follows it -- confirm
# whether raw model output was meant to be shown here.
print('Transformer Output:')

print()
print('Transformer Evaluation:')
evaluate(JT, example_strings[0], example_strings[1])

Using device: cuda
Training...


  batch_labels_tensor = torch.tensor(batch_labels, dtype=torch.float32, device=device).unsqueeze(1)


Epoch 10/50, Loss: 0.6945
Epoch 20/50, Loss: 0.5384
Epoch 30/50, Loss: 0.4886
Epoch 40/50, Loss: 0.4672
Epoch 50/50, Loss: 0.4647

Transformer Output:

Transformer Evaluation:
1: String: qCq, True: 1, Predicted: 1, Probability: 0.9320
2: String: HxHH, True: 0, Predicted: 1, Probability: 0.9190
3: String: FeF, True: 1, Predicted: 1, Probability: 0.9443
4: String: MPMM, True: 0, Predicted: 0, Probability: 0.3952
5: String: fmf, True: 1, Predicted: 1, Probability: 0.6544
6: String: p3pp, True: 0, Predicted: 1, Probability: 0.8718
7: String: iXi, True: 1, Predicted: 1, Probability: 0.9352
8: String: ZM, True: 0, Predicted: 0, Probability: 0.3805
9: String: z5z, True: 1, Predicted: 0, Probability: 0.4338
10: String: UVU, True: 1, Predicted: 1, Probability: 0.9187
11: String: eSe, True: 1, Predicted: 1, Probability: 0.8885
12: String: uP, True: 0, Predicted: 0, Probability: 0.4183
13: String: cTcc, True: 0, Predicted: 0, Probability: 0.3811
14: String: Mz, True: 0, Predicted: 0, Probability: