In [None]:
import numpy as np
import torch
import math
import json
import random
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torch.nn as nn
from sentence_transformers import SentenceTransformer
from torch.utils.data import DataLoader

In [None]:
# Locations of the artifacts used by the cells below.
# NOTE(review): these are absolute, machine-specific paths; consider deriving
# them from a configurable project root so the notebook runs elsewhere.
# Instruction file parsed by the data-loading cell.
path = "/Users/ibrahimbaldediallo/Documents/Code/Jarvis_project/vocab/instruction.txt"
# Checkpoint that Actor.__init__ loads into the model.
model_path = "/Users/ibrahimbaldediallo/Documents/Code/Jarvis_project/TD/actor3.pth"
# JSON file parsed into the id -> action table.
id_to_action_path = "/Users/ibrahimbaldediallo/Documents/Code/Jarvis_project/notebook/id_to_action.json"
# JSON file parsed into the action -> id table.
action_to_id_path = "/Users/ibrahimbaldediallo/Documents/Code/Jarvis_project/notebook/action_to_id.json"

In [None]:
# Read the id -> action table. JSON object keys arrive as strings, so they are
# cast back to int to make integer lookups work.
with open(id_to_action_path) as f:
    id_to_action_raw = json.load(f)

id_to_action = dict(
    (int(token_id), action) for token_id, action in id_to_action_raw.items()
)

# Read the action -> id table and cast every id to int.
with open(action_to_id_path) as f:
    action_to_id_raw = json.load(f)

action_to_id = dict(
    (action, int(token_id)) for action, token_id in action_to_id_raw.items()
)

In [None]:
# Vocabulary size = number of entries in the action -> id table.
vocab_size = len(action_to_id)
print(vocab_size)

In [None]:
# Pretrained sentence-embedding model used to encode the instruction texts.
encoder = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

In [None]:
def set_seed(seed=50):
    """Seed Python's ``random``, NumPy and PyTorch with the same value.

    Args:
        seed: Integer handed to each library's seeding function (default 50).
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

In [None]:
# Seed every random number generator (default seed) before any data or model work.
set_seed()

In [None]:
# Use Apple's Metal (MPS) backend when it is available, otherwise the CPU.
# `torch.backends.mps.is_available()` is the documented availability check;
# `torch.mps.is_available()` is not present in every PyTorch release that
# ships MPS support, so relying on it can raise AttributeError.
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')

In [None]:
l = []  # every non-empty line of the instruction file, lower-cased and stripped
X = []  # first ';'-separated field of each line (the instruction text)
Y = []  # action string built from the second field

# Explicit UTF-8 so decoding does not depend on the platform's default encoding.
with open(path, "r", encoding="utf-8") as f:
    for line_number, line in enumerate(f, start=1):
        line = line.lower().strip()
        if not line:
            # A blank (typically trailing) line carries no sample; skip it
            # rather than failing on the missing second field.
            continue
        l.append(line)

        split_data = line.split(";")
        if len(split_data) < 2:
            # Fail with the offending location instead of a bare IndexError.
            raise ValueError(
                f"{path}:{line_number}: expected two ';'-separated fields, got {line!r}"
            )

        x = split_data[0]
        # Collapse the second field into a single whitespace-free token, then
        # wrap it between the fixed "cmd+space" and "enter" tokens.
        y = split_data[1].replace("system preferences", "system_preferences")
        y = y.replace(" ", "")
        y = f"cmd+space {y} enter"

        X.append(x)
        Y.append(y)

In [None]:
# Sanity check: show the first instruction and the action string built for it.
print(X[0])
print(Y[0])

In [None]:
Y_idx_sentence = []
Y_idx = []
# Turn every whitespace-separated token of each action string into its
# vocabulary id; an unknown token raises KeyError.
for s in Y:
    s = list(map(action_to_id.__getitem__, s.split()))
    Y_idx.append(s)


In [None]:
from torch.nn.utils.rnn import pad_sequence

# Special-token ids, with fallbacks when the vocabulary does not define them.
PAD = action_to_id.get("<PAD>", 0)
BOS = action_to_id.get("<BOS>", 1)
EOS = action_to_id.get("<EOS>", 2)
UNK = action_to_id.get("<UNK>", 3)

# Embed all instructions in one batched call instead of one encoder forward
# pass per sentence; iterating the resulting 2-D tensor yields one embedding
# per instruction, in the same order as X.
encoder_inputs = list(encoder.encode(X, convert_to_tensor=True).to(device))

# Teacher forcing: the decoder reads <BOS> + y and must predict y + <EOS>.
decoder_inputs = [torch.tensor([BOS] + y, dtype=torch.long) for y in Y_idx]
decoder_targets = [torch.tensor(y + [EOS], dtype=torch.long) for y in Y_idx]

# Right-pad every sequence with PAD up to the longest one.
decoder_inputs_padded = pad_sequence(decoder_inputs, batch_first=True, padding_value=PAD)
decoder_targets_padded = pad_sequence(decoder_targets, batch_first=True, padding_value=PAD)

# Inputs and targets both hold len(y) + 1 tokens, so the padded width is the
# longest sequence length for either of them.
max_len = decoder_inputs_padded.size(1)

# One sample = (sentence embedding, (padded decoder input, padded decoder target)).
data = list(zip(encoder_inputs, zip(decoder_inputs_padded, decoder_targets_padded)))



In [None]:
# Inspect two samples: print the shape of each one's padded decoder input.
a = data[0]
b = data[1]
print(a[1][0].shape)
print(b[1][0].shape)

In [None]:
def collate_fn(batch):
    """Stack a list of samples into batched tensors.

    Args:
        batch: Sequence of samples shaped like
            ``(encoder_tensor, (decoder_input, decoder_target))``.

    Returns:
        Dict with keys ``"encoder_input"``, ``"decoder_input"`` and
        ``"decoder_target"``, each holding the corresponding tensors stacked
        along a new leading dimension.
    """
    embeddings, inputs, targets = [], [], []
    for sample in batch:
        embeddings.append(sample[0])
        decoder_pair = sample[1]
        inputs.append(decoder_pair[0])
        targets.append(decoder_pair[1])

    return dict(
        encoder_input=torch.stack(embeddings),
        decoder_input=torch.stack(inputs),
        decoder_target=torch.stack(targets),
    )

In [None]:
from torch.utils.data import DataLoader


# DataLoader
# collate_fn is passed directly: wrapping it in a lambda that only forwards
# its argument adds an extra call and nothing else.
dataloader = DataLoader(
    data,
    batch_size=32,
    shuffle=True,
    collate_fn=collate_fn,
)

# Example iteration. The loop walks the whole loader, so `batch` ends up
# holding the last mini-batch, which later cells reuse.
for batch in dataloader:
    encoder_input = batch["encoder_input"]        # (B, D)
    decoder_input = batch["decoder_input"]        # (B, T)
    decoder_target = batch["decoder_target"]      # (B, T)

In [None]:
def decode(sequences, id_to_action, stop_token="<EOS>"):
    """Map sequences of token ids back to their tokens.

    Args:
        sequences: Iterable of id sequences. Nested lists, a 2-D tensor/array,
            or lists of scalar tensors are all accepted.
        id_to_action: Dict mapping integer ids to tokens.
        stop_token: Token at which decoding of a sequence stops; the stop
            token itself is not included in the output.

    Returns:
        A list with one list of tokens per input sequence. Ids missing from
        ``id_to_action`` decode to ``"<UNK>"``.
    """
    decoded_sequences = []
    for sequence in sequences:
        # Tensor/array rows are converted to plain Python numbers up front.
        if hasattr(sequence, "tolist"):
            sequence = sequence.tolist()

        decoded = []
        for idx in sequence:
            # A scalar tensor hashes by identity, so it would never match an
            # integer key and every token would come back as "<UNK>".
            if hasattr(idx, "item"):
                idx = idx.item()
            token = id_to_action.get(idx, "<UNK>")
            if token == stop_token:
                break
            decoded.append(token)
        decoded_sequences.append(decoded)
    return decoded_sequences

In [None]:
# Decode the targets of the current batch and print the raw ids followed by
# the decoded tokens.
decoded_targets = decode(batch["decoder_target"], id_to_action)
print(batch["decoder_target"].tolist())
print(decoded_targets)

In [None]:
# Print every id of the first decoder input of the batch next to its token.
for idx in batch["decoder_input"][0].tolist():
    print(idx, id_to_action.get(idx, "<UNK>"))

In [None]:
# NOTE(review): not referenced by any other visible cell; presumably intended
# as the cap on generated sequence length — confirm or remove.
MAX_LEN = 32

In [None]:
class ResidualFFN(nn.Module):
    """Feed-forward network: linear projection in, residual blocks, linear projection out.

    Args:
        input_dim: Size of the last dimension of the input.
        hidden_dim: Width used by the residual blocks.
        output_dim: Size of the last dimension of the output.
        num_blocks: Number of ``ResidualBlock`` modules applied in sequence.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_blocks=2):
        super().__init__()

        # Entry projection into the hidden width.
        self.input_proj = nn.Linear(input_dim, hidden_dim)

        # Stack of residual blocks, all operating at the hidden width.
        blocks = []
        for _ in range(num_blocks):
            blocks.append(ResidualBlock(hidden_dim))
        self.res_blocks = nn.ModuleList(blocks)

        # Exit projection to the requested output width.
        self.output_proj = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Project ``x``, run it through every residual block, and project out."""
        hidden = self.input_proj(x)
        for residual_block in self.res_blocks:
            hidden = residual_block(hidden)
        return self.output_proj(hidden)
        
class ResidualBlock(nn.Module):
    """Two-layer GELU MLP with a skip connection followed by LayerNorm.

    Args:
        dim: Input and output width of the block.
        dropout: Dropout probability applied after the activation.
    """

    def __init__(self, dim, dropout=0.3):
        super().__init__()
        expanded_dim = dim * 4  # the inner layer is four times wider
        self.layers = nn.Sequential(
            nn.Linear(dim, expanded_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(expanded_dim, dim),
        )
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        """Return ``LayerNorm(x + MLP(x))``."""
        skip = x
        transformed = self.layers(x)
        return self.norm(skip + transformed)
    

class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding for sequence-first inputs.

    The table is precomputed for ``max_len`` positions and stored as the
    non-trainable buffer ``pe`` of shape ``(max_len, 1, dim)``.

    Args:
        dim: Embedding width. Both even and odd values are supported.
        max_len: Number of positions to precompute.
    """

    def __init__(self, dim, max_len=512):
        super().__init__()
        pe = torch.zeros(max_len, dim)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, dim, 2).float() * (-math.log(10000.0) / dim))
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd dim there is one fewer cosine column than sine column,
        # so the frequencies are trimmed to fit (a no-op when dim is even).
        pe[:, 1::2] = torch.cos(position * div_term[: dim // 2])
        pe = pe.unsqueeze(1)  # (max_len, 1, dim)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add the encoding of the first ``x.size(0)`` positions to ``x``.

        Args:
            x: Tensor of shape ``(seq_len, batch_size, dim)``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        x = x + self.pe[:x.size(0)]
        return x

class Actor(nn.Module):
    """Transformer decoder that emits action-token sequences from a sentence embedding.

    The sentence embedding is projected by a ``ResidualFFN`` and used as a
    single-element memory for a stack of ``nn.TransformerDecoder`` layers.

    Args:
        encoder: Pretrained SentenceTransformer used by ``forward`` and
            ``generate`` to embed raw texts.
        dim: Decoder model width.
        hidden: Hidden width of the ``ResidualFFN`` and of the decoder
            feed-forward layers.
        vocab_size: Number of action tokens.
        max_len: Number of positions covered by the positional encoding.
        checkpoint_path: State-dict file loaded at construction time. When
            ``None`` the module-level ``model_path`` is used.
    """

    def __init__(self, encoder, dim, hidden, vocab_size, max_len=128, checkpoint_path=None):
        super().__init__()
        self.encoder = encoder  # pretrained SentenceTransformer
        # 384 is the sentence-embedding width this model expects as input.
        self.rffn = ResidualFFN(384, hidden, dim)
        self.embedding = nn.Embedding(vocab_size, dim)
        self.pos_encoding = PositionalEncoding(dim, max_len=max_len)
        self.decoder_layer = nn.TransformerDecoderLayer(d_model=dim, nhead=16, dim_feedforward=hidden, dropout=0.3)
        self.transformer_decoder = nn.TransformerDecoder(self.decoder_layer, num_layers=2)  # reduced from 6 to 2 layers
        self.final_projection = nn.Linear(dim, vocab_size)
        self.max_len = max_len
        self.dim = dim
        self.vocab_size = vocab_size

        if checkpoint_path is None:
            checkpoint_path = model_path
        # map_location="cpu" lets a checkpoint saved on an accelerator load on
        # any machine; the caller moves the module to its device afterwards.
        # NOTE(review): torch.load unpickles the file, so only trusted
        # checkpoints should be loaded (consider weights_only=True).
        self.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))

    def _decode_logits(self, memory, tgt):
        """Run the decoder stack over token ids and return vocabulary logits.

        Args:
            memory: Tensor of shape ``(1, batch_size, dim)``.
            tgt: Long tensor of shape ``(batch_size, seq_len)``.

        Returns:
            Tensor of shape ``(batch_size, seq_len, vocab_size)``.
        """
        tgt = self.embedding(tgt)  # (batch_size, seq_len, dim)
        tgt = tgt.permute(1, 0, 2)  # (seq_len, batch_size, dim)
        tgt = self.pos_encoding(tgt)

        # Causal mask: each position may only attend to itself and earlier ones.
        seq_len = tgt.size(0)
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(seq_len).to(tgt.device)

        z = self.transformer_decoder(tgt, memory, tgt_mask=tgt_mask)  # (seq_len, batch_size, dim)
        z = self.final_projection(z)  # (seq_len, batch_size, vocab_size)
        return z.permute(1, 0, 2)  # (batch_size, seq_len, vocab_size)

    def forward(self, x_texts, tgt):
        """Compute logits from raw texts and decoder input ids.

        Args:
            x_texts: List of strings, one per batch element.
            tgt: Long tensor of shape ``(batch_size, seq_len)``.

        Returns:
            Tensor of shape ``(batch_size, seq_len, vocab_size)``.
        """
        # The sentence encoder is used as a frozen feature extractor.
        with torch.no_grad():
            x = self.encoder.encode(x_texts, convert_to_tensor=True)  # (batch_size, 384)

        # Work on the device holding this module's weights rather than on a
        # module-level global.
        model_device = self.final_projection.weight.device
        # Single-element memory, the same layout forward_training and generate use.
        memory = self.rffn(x.to(model_device)).unsqueeze(0)  # (1, batch_size, dim)
        return self._decode_logits(memory, tgt.to(model_device))

    def forward_training(self, x, tgt):
        """Compute logits from precomputed sentence embeddings.

        Args:
            x: Sentence embeddings of shape ``(batch_size, 384)``.
            tgt: Long tensor of shape ``(batch_size, seq_len)``.

        Returns:
            Tensor of shape ``(batch_size, seq_len, vocab_size)``.
        """
        # Project the embedding to the decoder width and use it as a
        # one-element memory sequence.
        memory = self.rffn(x).unsqueeze(0)  # (1, batch_size, dim)
        return self._decode_logits(memory, tgt)

    @torch.no_grad()
    def generate(self, x_text:list[str], max_len=32, start_token_id=1, end_token_id=2):
        """Greedy-decode one id sequence per input text.

        Decoding runs in evaluation mode (dropout disabled); the previous
        train/eval state of every submodule is restored afterwards.

        Args:
            x_text: List of strings.
            max_len: Maximum number of tokens to generate; capped at the
                number of positions the positional encoding provides.
            start_token_id: Id every sequence starts with.
            end_token_id: Id that ends a sequence. Once a sequence has emitted
                it, its remaining positions are filled with this id.

        Returns:
            Long tensor of shape ``(batch_size, generated_len)`` whose first
            column is ``start_token_id``.
        """
        previous_modes = [(module, module.training) for module in self.modules()]
        self.eval()
        try:
            x = self.encoder.encode(x_text, convert_to_tensor=True)
            model_device = self.final_projection.weight.device
            memory = self.rffn(x.to(model_device)).unsqueeze(0)  # (1, batch_size, dim)

            batch_size = memory.size(1)
            device = memory.device

            # Start every sequence with <BOS>.
            generated = torch.full((batch_size, 1), start_token_id, dtype=torch.long, device=device)
            finished = torch.zeros(batch_size, dtype=torch.bool, device=device)

            # Never feed the decoder more positions than the table holds.
            steps = min(max_len, self.pos_encoding.pe.size(0))
            for _ in range(steps):
                # Logits of the last time step -> (batch_size, vocab_size).
                next_token_logits = self._decode_logits(memory, generated)[:, -1, :]

                # Greedy choice; sequences that already ended keep emitting <EOS>.
                next_token = next_token_logits.argmax(dim=-1)
                next_token = next_token.masked_fill(finished, end_token_id)

                generated = torch.cat([generated, next_token.unsqueeze(1)], dim=1)

                # Stop once every sequence has produced <EOS> at some step.
                finished = finished | (next_token == end_token_id)
                if finished.all():
                    break

            return generated  # (batch_size, seq_len_generated)
        finally:
            for module, was_training in previous_modes:
                module.training = was_training

In [None]:
# Model width (dim) and hidden width (hidden) passed to the Actor constructor.
dim, hidden, = 512, 512

In [None]:
# Build the Actor (its constructor also loads the checkpoint at model_path)
# and move it to the selected device.
actor = Actor(encoder, dim, hidden, vocab_size).to(device)

In [None]:
# Show the module tree of the model.
print(actor)

In [None]:
# Inspect the current batch: encoder input shape, then the first target and
# the first decoder input, each shown in full and with one token sliced off.
print(batch["encoder_input"].shape)
print(batch["decoder_target"][0])
print(batch["decoder_target"][:, 1:][0])
print(batch["decoder_input"][0])
print(batch["decoder_input"][:, :-1][0])

In [None]:
# Smoke test: one forward_training pass on the current batch, printing the
# logits and their shape.
# NOTE(review): this feeds decoder_target (not decoder_input) as the decoder
# input — harmless for a shape check, but confirm it is intentional.
p = actor.forward_training(batch["encoder_input"].to(device), batch["decoder_target"].to(device))
print(p)
print(p.shape)

In [None]:
def train_model(model, train_dataset, test_dataset, epochs, learning_rate):
    """Train ``model`` with teacher forcing and return the mean loss of each epoch.

    Args:
        model: Module exposing ``forward_training(encoder_input, decoder_input)``
            that returns logits of shape ``(batch_size, seq_len, vocab_size)``.
        train_dataset: Sized iterable of batches; each batch is a dict with the
            keys ``"encoder_input"``, ``"decoder_input"`` and ``"decoder_target"``.
        test_dataset: Currently unused (validation is disabled); kept so the
            call signature stays stable.
        epochs: Number of passes over ``train_dataset``; also the period of
            the cosine learning-rate schedule.
        learning_rate: Initial AdamW learning rate.

    Returns:
        List with the average training loss of every epoch.
    """
    model = model.to(device)
    model.train()

    # Targets equal to 0 do not contribute to the loss.
    # NOTE(review): presumably 0 is the <PAD> id — confirm against the vocabulary.
    criterion = nn.CrossEntropyLoss(ignore_index=0)
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-6)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    epoch_losses = []

    for epoch in range(epochs):
        total_loss = 0

        for batch in train_dataset:
            encoder_input = batch["encoder_input"].to(device)
            # Input and target are used as provided by the batch, without slicing.
            decoder_input = batch["decoder_input"].to(device)
            decoder_target = batch["decoder_target"].to(device)

            optimizer.zero_grad()
            output = model.forward_training(encoder_input, decoder_input)  # (batch_size, seq_len, vocab_size)

            # Flatten with the model's own output width instead of relying on
            # a module-level vocab_size that may not match the model.
            loss = criterion(output.reshape(-1, output.size(-1)), decoder_target.reshape(-1))
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / len(train_dataset)
        epoch_losses.append(avg_loss)

        # The learning-rate schedule advances once per epoch.
        scheduler.step()

        print(f"Epoch {epoch+1}/{epochs}, Training Loss: {avg_loss:.4f}")

    return epoch_losses

In [None]:
# Fine-tune the loaded model for one epoch; no validation set is passed.
train_model(actor, dataloader, None, epochs=1, learning_rate=1e-6)  # best lr for now is 1e-4, 4e-5, 2e-5 with 2 epochs

In [None]:
# Save the model's state dict to 'actor4.pth' (relative to the working directory).
torch.save(actor.state_dict(), 'actor4.pth')