## LOCAL AND GLOBAL EMBEDDINGS

This file its copy of the file **Transformers** but uses local and global context for each block of the scripts

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from collections import Counter

import math
import numpy as np
import re
import os
from torch.nn.utils.rnn import pad_sequence


os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

torch.manual_seed(23)

<torch._C.Generator at 0x767bf80d0f90>

In [2]:
def check_gpu():
    if torch.cuda.is_available():
        print("CUDA está disponible.")
        print(f"Hay {torch.cuda.device_count()} GPU(s) disponible(s).")
        for i in range(torch.cuda.device_count()):
            print(f"GPU {i}: {torch.cuda.get_device_name(i)}")
    else:
        print("CUDA no está disponible. No hay GPU accesible.")

check_gpu()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#device = torch.device('cpu')


CUDA está disponible.
Hay 1 GPU(s) disponible(s).
GPU 0: NVIDIA GeForce RTX 2060


In [3]:
MAX_SEQ_LEN = 128 # max num of words per phrase for translate

In [4]:
class PositionalEmbedding(nn.Module):
    def __init__(self, d_model, max_seq_len = MAX_SEQ_LEN):
        super().__init__()
        self.pos_embed_matrix = torch.zeros(max_seq_len, d_model, device=device)
        token_pos = torch.arange(0, max_seq_len, dtype = torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() 
                             * (-math.log(10000.0)/d_model))
        self.pos_embed_matrix[:, 0::2] = torch.sin(token_pos * div_term)
        self.pos_embed_matrix[:, 1::2] = torch.cos(token_pos * div_term)
        self.pos_embed_matrix = self.pos_embed_matrix.unsqueeze(0).transpose(0,1)
        
    def forward(self, x):
#         print(self.pos_embed_matrix.shape)
#         print(x.shape)
        return x + self.pos_embed_matrix[:x.size(0), :]

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model = 512, num_heads = 8):
        super().__init__()
        assert d_model % num_heads == 0, 'Embedding size not compatible with num heads'
        
        self.d_v = d_model // num_heads
        self.d_k = self.d_v
        self.num_heads = num_heads
        
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        
    def forward(self, Q, K, V, mask = None):
        batch_size = Q.size(0)
        '''
        Q, K, V -> [batch_size, seq_len, num_heads*d_k]
        after transpose Q, K, V -> [batch_size, num_heads, seq_len, d_k]
        '''
        Q = self.W_q(Q).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2 )
        K = self.W_k(K).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2 )
        V = self.W_v(V).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2 )
        
        weighted_values, attention = self.scale_dot_product(Q, K, V, mask)
        weighted_values = weighted_values.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads*self.d_k)
        weighted_values = self.W_o(weighted_values)
        
        return weighted_values, attention
        
        
    def scale_dot_product(self, Q, K, V, mask = None):
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        attention = F.softmax(scores, dim = -1)
        weighted_values = torch.matmul(attention, V)
        
        return weighted_values, attention
        

class PositionFeedForward(nn.Module):
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        
    def forward(self, x):
        return self.linear2(F.relu(self.linear1(x)))
    
class EncoderSubLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout = 0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.ffn = PositionFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.droupout1 = nn.Dropout(dropout)
        self.droupout2 = nn.Dropout(dropout)
    
    def forward(self, x, mask = None):
        attention_score, _ = self.self_attn(x, x, x, mask)
        x = x + self.droupout1(attention_score)
        x = self.norm1(x)
        x = x + self.droupout2(self.ffn(x))
        return self.norm2(x)

class Encoder(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, num_layers, dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList([EncoderSubLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
        self.norm = nn.LayerNorm(d_model)
    def forward(self, x, mask=None):
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)

class DecoderSubLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)
        
    def forward(self, x, encoder_output, target_mask=None, encoder_mask=None):
        attention_score, _ = self.self_attn(x, x, x, target_mask)
        x = x + self.dropout1(attention_score)
        x = self.norm1(x)
        
        encoder_attn, _ = self.cross_attn(x, encoder_output, encoder_output, encoder_mask)
        x = x + self.dropout2(encoder_attn)
        x = self.norm2(x)
        
        ff_output = self.feed_forward(x)
        x = x + self.dropout3(ff_output)
        return self.norm3(x)
        
class Decoder(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, num_layers, dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList([DecoderSubLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
        self.norm = nn.LayerNorm(d_model)
        
    def forward(self, x, encoder_output, target_mask, encoder_mask):
        for layer in self.layers:
            x = layer(x, encoder_output, target_mask, encoder_mask)
        return self.norm(x)

In [5]:
class Transformer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, num_layers,
                 input_vocab_size, max_len=MAX_SEQ_LEN, dropout=0.1):
        super().__init__()
        self.encoder_embedding = nn.Embedding(input_vocab_size, d_model)
        self.pos_embedding = PositionalEmbedding(d_model, max_len)
        self.encoder = Encoder(d_model, num_heads, d_ff, num_layers, dropout)
        self.sep_token_id = 2

        # Embeddings
        self._cached_source_embeddings = None
        self._all_embeddings = None
        
    def forward(self, source):
        # Encoder mask
        source_mask = self.mask(source)
        # Embedding and positional Encoding
        source = self.encoder_embedding(source) * math.sqrt(self.encoder_embedding.embedding_dim)
        self._cached_source_embeddings = source
        source = self.pos_embedding(source)
        # Encoder
        encoder_output = self.encoder(source, source_mask)
        
        return encoder_output
        
    def get_embeddings(self):
        if self._all_embeddings is None:
            raise ValueError("Embeddings not computed yet. Call forward() first.")
        return self._all_embeddings
    
    def mask(self, source):
        source_mask = (source != 0).unsqueeze(1).unsqueeze(2)
        
        if self.sep_token_id is not None:
            seq_len = source.size(1)
            sep_positions = (source == self.sep_token_id).nonzero(as_tuple=True)[1]
            
            attention_block_mask = torch.ones(seq_len, seq_len, dtype=torch.bool, device=source.device)

            last_sep_pos = -1
            for sep_pos in sep_positions:
                attention_block_mask[last_sep_pos + 1 : sep_pos, last_sep_pos + 1 : sep_pos] = 0
                last_sep_pos = sep_pos

            source_mask = source_mask & attention_block_mask.unsqueeze(0).unsqueeze(0)

        return source_mask

        

## Simple test

In [6]:
seq_len_source = 10
seq_len_target = 10
batch_size = 2
input_vocab_size = 50
target_vocab_size = 50

source = torch.randint(1, input_vocab_size, (batch_size, seq_len_source))

d_model = 512
num_heads = 8
d_ff = 2048
num_layers = 6

model = Transformer(d_model, num_heads, d_ff, num_layers,
                   input_vocab_size, max_len=MAX_SEQ_LEN, dropout=0.1)


model = model.to(device)
source = source.to(device)

output = model(source)
#Expected output shape -> [batch, seq_len_target, target_vocab_size] i.e [2, 10, 50]
print(f'output.shape {output.shape}')
print(source)

output.shape torch.Size([2, 10, 512])
tensor([[39, 11,  6, 31, 11, 42, 49, 35, 10, 20],
        [24, 12, 35, 20, 30,  6, 23, 39, 48,  5]], device='cuda:0')


## DATA PREPROCESSING: Creation of scripts list with all project scripts

In [7]:
import pandas as pd
import os
from zipfile import ZipFile, BadZipFile
import json

metrics = pd.read_csv('metrics_attr.csv')

# Filter by Action Genre
metrics_action = metrics[(metrics['Main Genre'] == 'Action')]
metrics_non_action = metrics[(metrics['Main Genre'] != 'Action')]


# Check the filenames for collect
filenames_action = list(metrics_action['Name'])
filenames_non_action = list(metrics_non_action['Name'])
# Create txt for made the shell script
with open('filenames_action_global.txt', 'w') as names:
     for name in filenames_action:
        names.write(name + '\n')


#for project in filenames_action:
 #   sb3_path = f'./projects_sb3/{project}'
  #  if os.path.isfile(sb3_path):
   #     shutil.copy(sb3_path, './sb3_action_global')
    #    print(f'The project {project} has been success copy')
    #else:
     #   print(f'The project {project} doesnt exists')

In [8]:


# ----------------- SCRIPTS GLOBAL (local context) ----------------------------------
dict_total_blocks = {}
scripts_global = []
list_total_blocks = []
print(len(filenames_action))
for project in filenames_action:
    sb3_path = os.path.join('.','sb3_action_global',project)
    if os.path.isfile(sb3_path):
        #print(project)
        json_project = load_json_project(sb3_path)
        dict_total_blocks = process(json_project)

        for sprite, seqs in dict_total_blocks.items():
            for idx, block_list in seqs.items():
                if block_list != []:
                    scripts_global.append(" ".join(block_list))

# ----------------- SCRIPTS TARGET ----------------------------------

scripts_train1 = []
scripts_train2 = []
for idx, project in enumerate(filenames_action):
    sb3_path = os.path.join('.','sb3_action_global',project)
    if os.path.isfile(sb3_path):
        #print(project)
        json_project = load_json_project(sb3_path)
        dict_total_blocks = process(json_project)

        for sprite, seqs in dict_total_blocks.items():
            for block_list in seqs.values():
                if block_list != []:
                    if idx < math.floor(len(filenames_action)/2):
                        scripts_train1.append(" ".join(block_list))
                    else:
                        scripts_train2.append(" ".join(block_list))


# ----------------- SCRIPTS NEGTATIVE  ----------------------------------
scripts_train3 = []
for project in filenames_non_action:
    sb3_path = os.path.join('.','sb3_non_action',project)
    if os.path.isfile(sb3_path):
        #print(project)
        json_project = load_json_project(sb3_path)
        dict_total_blocks = process(json_project)

        for sprite, seqs in dict_total_blocks.items():
            for idx, block_list in seqs.items():
                if block_list != []:
                    scripts_train3.append(" ".join(block_list))

# ----------------- SCRIPTS GLOBAL (global context)  ----------------------------------

dict_total_blocks = {}
scripts_global_global = []
list_total_blocks = []
print(len(filenames_action))
print(filenames_action[0])
for project in filenames_action:
    
    sb3_path = os.path.join('.','sb3_action_global',project)
    if os.path.isfile(sb3_path):
        #print(project)
        json_project = load_json_project(sb3_path)
        dict_total_blocks = process(json_project)

        for sprite, seqs in dict_total_blocks.items():
            global_seqs = []
            #print("project:", project)
            #print(seqs)
            #print("Sprite",sprite)
            
            for idx, block_list in seqs.items():
                if block_list != []:
                    global_seqs.append(" ".join(block_list))
                    global_seqs.append("<SCRIPT_END>")
            if global_seqs != []:
                global_seqs.pop(-1)
            
            scripts_global_global.append(" ".join(global_seqs))


# ----------------- SCRIPTS TARGET (POSITIV) (global context)  ----------------------------------

scripts_train1_global = []
scripts_train2_global = []
dict_total_blocks = {}
list_total_blocks = []
print(len(filenames_action))
for idx, project in enumerate(filenames_action):
    
    sb3_path = os.path.join('.','sb3_action_global',project)
    if os.path.isfile(sb3_path):
        #print(project)
        json_project = load_json_project(sb3_path)
        dict_total_blocks = process(json_project)

        for sprite, seqs in dict_total_blocks.items():
            global_seqs = []
            #print("project:", project)
            #print(seqs)
            #print("Sprite",sprite)
            
            for block_list in seqs.values():
                if block_list != []:  
                    global_seqs.append(" ".join(block_list))
                    global_seqs.append("<SCRIPT_END>")
            if global_seqs != []:
                global_seqs.pop(-1)
        
            if int(idx) < math.floor(len(filenames_action)/2):
                scripts_train1_global.append(" ".join(global_seqs))
            else:
                scripts_train2_global.append(" ".join(global_seqs))

# ----------------- SCRIPTS NEGATIVE (global context) -----------------------------

dict_total_blocks = {}
scripts_train3_global= []
list_total_blocks = []
print(len(filenames_non_action))
for project in filenames_non_action:
    
    sb3_path = os.path.join('.','sb3_non_action',project)
    if os.path.isfile(sb3_path):
        #print(project)
        json_project = load_json_project(sb3_path)
        dict_total_blocks = process(json_project)

        for sprite, seqs in dict_total_blocks.items():
            global_seqs = []
            #print("project:", project)
            #print(seqs)
            #print("Sprite",sprite)
            
            for idx, block_list in seqs.items():
                if block_list != []:
                    global_seqs.append(" ".join(block_list))
                    global_seqs.append("<SCRIPT_END>")
            if global_seqs != []:
                global_seqs.pop(-1)
            
            scripts_train3_global.append(" ".join(global_seqs))


# universal_scripts -> scripts of all genres for local context
# scripts_global -> global scripts (local context)... each index of the list its an scripts ACTION
# scripts_train1 -> src
# scritps_train2 -> trg
# scritps_train3 -> negative
# scripts_global_global -> global scripts (global context)... each index of the list its all project scripts 

# --------------------- LOCAL CONTEXT ---------------------------------------
min_len = min(len(scripts_train1), len(scripts_train2), len(scripts_train3))
scripts_train1 = scripts_train1[:min_len]
scripts_train2 = scripts_train2[:min_len]
scripts_train3 = scripts_train3[:min_len]

scripts_universal = scripts_train1 + scripts_train2 + scripts_train3
scripts_global = ['<sos> ' + script + ' <eos>' for script in scripts_global]
scripts_train1 = ['<sos> ' + script + ' <eos>' for script in scripts_train1]
scripts_train2 = ['<sos> ' + script + ' <eos>' for script in scripts_train2]
scripts_train3 = ['<sos> ' + script + ' <eos>' for script in scripts_train3]

# --------------------- GLOBAL CONTEXT ---------------------------------------
min_len = min(len(scripts_train1_global), len(scripts_train2_global), len(scripts_train3_global))
scripts_train1_global = scripts_train1_global[:min_len]
scripts_train2_global = scripts_train2_global[:min_len]
scripts_train3_global = scripts_train3_global[:min_len]

scripts_universal_global = scripts_train1_global + scripts_train2_global + scripts_train3_global
scripts_global_global = ['<sos> ' + script + ' <eos>' for script in scripts_global_global]
scripts_train1_global = ['<sos> ' + script + ' <eos>' for script in scripts_train1_global]
scripts_train2_global = ['<sos> ' + script + ' <eos>' for script in scripts_train2_global]
scripts_train3_global = ['<sos> ' + script + ' <eos>' for script in scripts_train3_global]

312
312
Abby and Grace's project.sb3
312
328


### CREATION OF DATASET FOR **TRAIN**

In [9]:
def build_vocab(scripts):
    blocks = [block for script in scripts for block in script.split() ]
    block_count = Counter(blocks)
    sorted_block_counts = sorted(block_count.items(), key=lambda x:x[1], reverse=True)
    block2idx = {
        '<pad>': 0,
        '<unk>': 1,
        '<SCRIPT_END>': 2
    }
    for idx, (block, _) in enumerate(sorted_block_counts, 3):
        block2idx[block] = idx
    idx2block = {idx: block for block, idx in block2idx.items()}
    return block2idx, idx2block

class TripletDataset(Dataset):
    def __init__(self, src_sentences, trg_sentences, neg_sentences, uni_block2idx):
        self.src_sentences = src_sentences  # Anchor (source)
        self.trg_sentences = trg_sentences  # Positive (target)
        self.neg_sentences = neg_sentences  # Negative
        self.src_block2idx = uni_block2idx
        self.trg_block2idx = uni_block2idx
        self.neg_block2idx = uni_block2idx
        
    def __len__(self):
        return len(self.trg_sentences)  # Usamos la longitud del conjunto positivo
    
    def __getitem__(self, idx):
        # Obtener las oraciones (anchor, positive, negative)
        src_sentence = self.src_sentences[idx]
        trg_sentence = self.trg_sentences[idx]
        neg_sentence = self.neg_sentences[idx]

        # Convertir cada oración en índices
        src_idxs = [self.src_block2idx.get(block, self.src_block2idx['<unk>']) for block in src_sentence.split()]
        trg_idxs = [self.trg_block2idx.get(block, self.trg_block2idx['<unk>']) for block in trg_sentence.split()]
        neg_idxs = [self.neg_block2idx.get(block, self.neg_block2idx['<unk>']) for block in neg_sentence.split()]

        # Retornar los tensores (anchor, positive, negative)
        return torch.tensor(src_idxs), torch.tensor(trg_idxs), torch.tensor(neg_idxs)

In [10]:
scripts_universal = scripts_train1 + scripts_train2 + scripts_train3


# ------------- Universal scripts ----------------------------
uni_block2idx, uni_idx2block = build_vocab(scripts_universal)
uni_vocab_size = len(uni_block2idx)

# ------------- Local context --------------------------------
src_block2idx, src_idx2block = build_vocab(scripts_train1)
#src_vocab_size = len(src_block2idx)
src_vocab_size = uni_vocab_size
trg_block2idx, trg_idx2block = build_vocab(scripts_train2)
#trg_vocab_size = len(trg_block2idx)
trg_vocab_size = uni_vocab_size
neg_block2idx, neg_idx2block = build_vocab(scripts_train3)
#neg_vocab_size = len(neg_block2idx)
neg_vocab_size = uni_vocab_size
print(trg_vocab_size)
print(neg_vocab_size)
print(uni_block2idx)
# ------------ Global context -----------------------------
src_block2idx_global, src_idx2block_global = build_vocab(scripts_train1_global)
src_vocab_size_global = len(src_block2idx_global)
trg_block2idx_global, trg_idx2block_global = build_vocab(scripts_train2_global)
trg_vocab_size_global = len(trg_block2idx_global)
neg_block2idx_global, neg_idx2block_global = build_vocab(scripts_train3_global)
neg_vocab_size_global = len(neg_block2idx_global)

142
142
{'<pad>': 0, '<unk>': 1, '<SCRIPT_END>': 2, '<sos>': 3, '<eos>': 4, 'looks_hide': 5, 'event_whenflagclicked': 6, 'event_whenbackdropswitchesto': 7, 'control_if': 8, 'looks_show': 9, 'control_wait': 10, 'control_forever': 11, 'data_changevariableby': 12, 'motion_gotoxy': 13, 'sensing_touchingobject': 14, 'sensing_touchingobjectmenu': 15, 'event_whenbroadcastreceived': 16, 'operator_equals': 17, 'data_setvariableto': 18, 'looks_switchcostumeto': 19, 'looks_costume': 20, 'looks_backdrops': 21, 'looks_switchbackdropto': 22, 'looks_sayforsecs': 23, 'event_whenkeypressed': 24, 'event_broadcast': 25, 'motion_glidesecstoxy': 26, 'event_whenthisspriteclicked': 27, 'operator_random': 28, 'sound_sounds_menu': 29, 'sound_play': 30, 'sensing_keypressed': 31, 'sensing_keyoptions': 32, 'control_stop': 33, 'control_wait_until': 34, 'operator_and': 35, 'data_hidevariable': 36, 'operator_gt': 37, 'sensing_touchingcolor': 38, 'operator_lt': 39, 'motion_movesteps': 40, 'motion_changeyby': 41, 'con

### CREATION OF DATASET FOR **TEST**

## TRAIN FUNCTIONS

In [11]:
def collate_fn_old(batch):
    trg_batch, src_batch, neg_batch = zip(*batch)
    trg_batch = [seq[:MAX_SEQ_LEN].clone().detach() for seq in trg_batch]
    src_batch = [seq[:MAX_SEQ_LEN].clone().detach() for seq in src_batch]
    neg_batch = [seq[:MAX_SEQ_LEN].clone().detach() for seq in neg_batch]
    trg_batch = torch.nn.utils.rnn.pad_sequence(trg_batch, batch_first=True, padding_value=0)
    src_batch = torch.nn.utils.rnn.pad_sequence(src_batch, batch_first=True, padding_value=0)
    neg_batch = torch.nn.utils.rnn.pad_sequence(neg_batch, batch_first=True, padding_value=0)
    return src_batch, trg_batch, neg_batch


def collate_fn(batch):
    src_batch, pos_batch, neg_batch = zip(*batch)
    src_batch = [seq.clone().detach() for seq in src_batch]
    pos_batch = [seq.clone().detach() for seq in pos_batch]
    neg_batch = [seq.clone().detach() for seq in neg_batch]
    #print(trg_batch)
    # Hacemos el padding sin truncar primero
    #trg_batch = torch.nn.utils.rnn.pad_sequence(trg_batch, batch_first=True, padding_value=0)
    #src_batch = torch.nn.utils.rnn.pad_sequence(src_batch, batch_first=True, padding_value=0)
    #neg_batch = torch.nn.utils.rnn.pad_sequence(neg_batch, batch_first=True, padding_value=0)
    
    # Ahora truncamos las secuencias después del padding a MAX_SEQ_LEN
    #trg_batch = trg_batch[:, :MAX_SEQ_LEN]
    #src_batch = src_batch[:, :MAX_SEQ_LEN]
    #neg_batch = neg_batch[:, :MAX_SEQ_LEN]

    src_batch = [torch.nn.functional.pad(seq[:MAX_SEQ_LEN], (0, MAX_SEQ_LEN - len(seq[:MAX_SEQ_LEN])), value=0) for seq in src_batch]
    pos_batch = [torch.nn.functional.pad(seq[:MAX_SEQ_LEN], (0, MAX_SEQ_LEN - len(seq[:MAX_SEQ_LEN])), value=0) for seq in pos_batch]
    neg_batch = [torch.nn.functional.pad(seq[:MAX_SEQ_LEN], (0, MAX_SEQ_LEN - len(seq[:MAX_SEQ_LEN])), value=0) for seq in neg_batch]

    src_batch = torch.stack(src_batch)    
    pos_batch = torch.stack(pos_batch)
    neg_batch = torch.stack(neg_batch)
    return src_batch, pos_batch, neg_batch

In [12]:
def train(model, dataloader, loss_function, optimiser, epochs):
    model.train()
    final_anchor_embeddings = []
    final_positive_embeddings = []
    final_negative_embeddings = []
    
    for epoch in range(epochs):
        total_loss = 0
        for i, (anchor_batch, positive_batch, negative_batch) in enumerate(dataloader):
            assert torch.max(anchor_batch) < src_vocab_size, f"Anchor out of bounds: {torch.max(anchor_batch)} >= {src_vocab_size}"
            assert torch.max(positive_batch) < trg_vocab_size, f"Positive out of bounds: {torch.max(positive_batch)} >= {trg_vocab_size}"
            assert torch.max(negative_batch) < trg_vocab_size, f"Negative out of bounds: {torch.max(negative_batch)} >= {trg_vocab_size}"
            
            anchor_batch = anchor_batch.to(device)
            positive_batch = positive_batch.to(device)
            negative_batch = negative_batch.to(device)
            
            # Zero grads
            optimiser.zero_grad()

            # Forward para anchor, positive y negative
            anchor_embeddings = model(anchor_batch)
            positive_embeddings = model(positive_batch)
            negative_embeddings = model(negative_batch)
            
            # Almacenar los embeddings solo en el último epoch
            if epoch == epochs - 1:
                final_anchor_embeddings.append(anchor_embeddings.cpu().detach())
                final_positive_embeddings.append(positive_embeddings.cpu().detach())
                final_negative_embeddings.append(negative_embeddings.cpu().detach())


            

            # Calcular la pérdida de Triplet
            loss = loss_function(anchor_embeddings, positive_embeddings, negative_embeddings)

            # Backpropagation y actualización de parámetros
            loss.backward()
            optimiser.step()

            total_loss += loss.item()
            
        avg_loss = total_loss / len(dataloader)
        print(f'Epoch: {epoch + 1}/{epochs}, Loss: {avg_loss:.4f}')
    
    # Concatenar los embeddings del último epoch
    final_anchor_embeddings = torch.cat(final_anchor_embeddings, dim=0)
    final_positive_embeddings = torch.cat(final_positive_embeddings, dim=0)
    final_negative_embeddings = torch.cat(final_negative_embeddings, dim=0)

    return final_anchor_embeddings, final_positive_embeddings, final_negative_embeddings



def evaluate(model, dataloader):
    model.eval()  # Configura el modelo en modo de evaluación
    with torch.no_grad():
        all_anchor_embeddings = []
        all_positive_embeddings = []
        all_negative_embeddings = []
        
        # Recopila los embeddings para todos los lotes
        for anchor_batch, positive_batch, negative_batch in dataloader:
            anchor_batch = anchor_batch.to(device)
            positive_batch = positive_batch.to(device)
            negative_batch = negative_batch.to(device)
            
            # Obtener embeddings
            anchor_embeddings = model.forward(anchor_batch)
            positive_embeddings = model.forward(positive_batch)
            negative_embeddings = model.forward(negative_batch)
            
            # Añadir a las listas
            all_anchor_embeddings.append(anchor_embeddings)
            all_positive_embeddings.append(positive_embeddings)
            all_negative_embeddings.append(negative_embeddings)
        
        # Concatenar todos los embeddings
        all_anchor_embeddings = torch.cat(all_anchor_embeddings, dim=0)
        all_positive_embeddings = torch.cat(all_positive_embeddings, dim=0)
        all_negative_embeddings = torch.cat(all_negative_embeddings, dim=0)
    return anchor_embeddings, positive_embeddings, negative_embeddings

## EXEC TRAIN

### Local

In [14]:
BATCH_SIZE = 16
dataset_local = TripletDataset(scripts_train1, scripts_train2, scripts_train3, uni_block2idx)
dataloader_local = DataLoader(dataset_local, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
model = Transformer(d_model=512, num_heads=8, d_ff=2048, num_layers=6,
                    input_vocab_size=uni_vocab_size,
                    max_len=MAX_SEQ_LEN, dropout=0.1)
model = model.to(device)
loss_function = triplet_loss = nn.TripletMarginLoss(margin=1.0)
optimiser = optim.Adam(model.parameters(), lr=0.0001)
# train local context
anchor_embeddings_local, positive_embeddings_local, negative_embeddings_local = train(model, dataloader_local, loss_function, optimiser, epochs = 5)


Epoch: 1/5, Loss: 0.8812
Epoch: 2/5, Loss: 0.6715
Epoch: 3/5, Loss: 0.5881
Epoch: 4/5, Loss: 0.5457
Epoch: 5/5, Loss: 0.5094


In [15]:
# save embeddings
model._all_embeddings = anchor_embeddings_local

# Save FULL model trained (arch and weights)
#torch.save(model, 'action_local_scratch_triplet.pth')

### Global

In [13]:

BATCH_SIZE = 16
dataset_global = TripletDataset(scripts_train1_global, scripts_train2_global, scripts_train3_global, uni_block2idx)
dataloader_global = DataLoader(dataset_global, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
model_global = Transformer(d_model=512, num_heads=8, d_ff=2048, num_layers=6,
                    input_vocab_size=uni_vocab_size,
                    max_len=MAX_SEQ_LEN, dropout=0.1)
model_global = model_global.to(device)
loss_function_global = triplet_loss = nn.TripletMarginLoss(margin=1.0)
optimiser_global = optim.Adam(model.parameters(), lr=0.0001)
# train global context
anchor_embeddings_global, positive_embeddings_global, negative_embeddings_global = train(model_global, dataloader_global, loss_function_global, optimiser_global, epochs = 50)


Epoch: 1/50, Loss: 2.3477
Epoch: 2/50, Loss: 2.3471
Epoch: 3/50, Loss: 2.3195
Epoch: 4/50, Loss: 2.3222
Epoch: 5/50, Loss: 2.3318
Epoch: 6/50, Loss: 2.3139
Epoch: 7/50, Loss: 2.3333
Epoch: 8/50, Loss: 2.3268
Epoch: 9/50, Loss: 2.3290
Epoch: 10/50, Loss: 2.3317
Epoch: 11/50, Loss: 2.3421
Epoch: 12/50, Loss: 2.3593
Epoch: 13/50, Loss: 2.3235
Epoch: 14/50, Loss: 2.3136
Epoch: 15/50, Loss: 2.3455
Epoch: 16/50, Loss: 2.3241
Epoch: 17/50, Loss: 2.3036
Epoch: 18/50, Loss: 2.3496
Epoch: 19/50, Loss: 2.3276
Epoch: 20/50, Loss: 2.3173
Epoch: 21/50, Loss: 2.3574
Epoch: 22/50, Loss: 2.3418
Epoch: 23/50, Loss: 2.2955
Epoch: 24/50, Loss: 2.3250
Epoch: 25/50, Loss: 2.3294
Epoch: 26/50, Loss: 2.3201
Epoch: 27/50, Loss: 2.3065
Epoch: 28/50, Loss: 2.3050
Epoch: 29/50, Loss: 2.3491
Epoch: 30/50, Loss: 2.3426
Epoch: 31/50, Loss: 2.3462
Epoch: 32/50, Loss: 2.3009
Epoch: 33/50, Loss: 2.3256
Epoch: 34/50, Loss: 2.3211
Epoch: 35/50, Loss: 2.3370
Epoch: 36/50, Loss: 2.3190
Epoch: 37/50, Loss: 2.3282
Epoch: 38/

In [14]:
# save embeddings
model_global._all_embeddings = anchor_embeddings_global

# Save FULL model trained (arch and weights)
torch.save(model_global, 'action_global_sprite_scratch_triplet.pth')

# GRAPH EMBEDDINGS

In [None]:
# Concat local and global embeddings

anchor_embeddings = torch.cat((anchor_embeddings_local, anchor_embeddings_global), dim=-1)
positive_embeddings = torch.cat((positive_embeddings_local, positive_embeddings_global), dim=-1)
negative_embeddings = torch.cat((negative_embeddings_local, negative_embeddings_global), dim=-1)