In [1]:
import json
import re
from nltk.tokenize import word_tokenize
from transformers import BertTokenizer
import torch
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence
from torch.nn.utils.rnn import pad_sequence
from sklearn.metrics import f1_score
import nltk
nltk.download('punkt_tab')

from tqdm import tqdm


[nltk_data] Downloading package punkt_tab to /usr/share/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


In [2]:
!pip install torch_geometric

Collecting torch_geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m2.6 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m29.2 MB/s[0m eta [36m0:00:00[0m00:01[0m
[?25hInstalling collected packages: torch_geometric
Successfully installed torch_geometric-2.6.1


In [3]:
import random

seed = 100

torch.manual_seed(seed)
random.seed(seed)
torch.cuda.manual_seed_all(seed)
random.seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [5]:
with open("/kaggle/input/diplomacy/2020_acl_diplomacy-master/data/test.jsonl", "r", encoding="utf-8") as file:
    test_data = [json.loads(line) for line in file]

with open("/kaggle/input/diplomacy/2020_acl_diplomacy-master/data/train.jsonl", "r", encoding="utf-8") as file:
    train_data = [json.loads(line) for line in file]

with open("/kaggle/input/diplomacy/2020_acl_diplomacy-master/data/validation.jsonl", "r", encoding="utf-8") as file:
    val_data = [json.loads(line) for line in file]



In [6]:
import re

def preprocess(sentence):
    sentence = sentence.lower()
    sentence = re.sub(r"[^a-zA-Z0-9 ]", "", sentence)  # remove punctuation (Glove)
    sentence = re.sub(r"\s+", " ", sentence).strip()
    return sentence

def prep_data_context(data, is_sender, is_train, tokenizer):
    final_data = []
    chunk_size = 230

    for data_points in data:
        messages = data_points["messages"]
        labels = data_points["sender_labels"] if is_sender else data_points["receiver_labels"]
        game_score_deltas = data_points["game_score_delta"]

        # For training, split into 300-message chunks
        chunks = (
            [messages[i:i+chunk_size] for i in range(0, len(messages), chunk_size)]
            if is_train else [messages]
        )

        for chunk_index, message_chunk in enumerate(chunks):
            sub = []
            for i, message in enumerate(message_chunk):
                index = i + chunk_index * chunk_size
                if index >= len(labels) or labels[index] == 'NOANNOTATION':
                    continue

                msg = preprocess(message)
                tokenized = tokenizer(msg, truncation=True,max_length=200)

                if len(tokenized["input_ids"]) == 0:
                    continue

                sub.append({
                    "message": tokenized["input_ids"],
                    "label": labels[index],
                    "game_score_delta": int(game_score_deltas[index])
                })

            if sub:
                final_data.append(sub)

    return final_data


In [7]:
from transformers import BertTokenizerFast
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
val=prep_data_context(val_data,1,False,tokenizer)
train=prep_data_context(train_data,1,True,tokenizer)
test=prep_data_context(test_data,1,False,tokenizer)

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

In [8]:
tokens = []
for sub in train:
  for data_p in sub:
    for word in data_p["message"]:
        tokens.append(word)
tokens=list(set(tokens))

vocab = {token:idx+2  for idx , token in enumerate(tokens)}
vocab["<PAD>"]=0
vocab["<UNK>"]=1


In [9]:

class Deception_dataset_context(Dataset):
    def __init__(self, data ):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
      data_p = self.data[idx]
      msg_ids = []
      
      for sub in data_p:
        
          
          msg_ids.append(torch.tensor(sub["message"], dtype=torch.long))
    
          # pritn(msg_ids[-1].shape)
      try:
          return {
              "messages": msg_ids,
              "labels": torch.tensor([i['label'] for i in data_p], dtype=torch.long),
              "game_score_delta": torch.tensor([i["game_score_delta"] for i in data_p], dtype=torch.float)
          }
      except Exception as e:
          print("issue: ", e)
          return



In [10]:
def collate_fn_context(batch):
    messages = []
    labels = []
    lengths = []
    game_score_deltas = []
    num_messages = []
  
    for i in batch:
      # print(len(i['messages']))
      # print((i['messages'][0]).shape)
      messages.extend(i['messages'])
  
      lengths.extend([len(j) for j in i['messages']])
      labels.extend(i['labels'])
      num_messages.append(len(i['messages']))
      game_score_deltas.extend(i['game_score_delta'])
    # print(len(messages))
    padded_messages = pad_sequence(messages,batch_first=True,padding_value=0)


    return {
        "messages": (padded_messages),
        "lengths": torch.tensor(lengths, dtype=torch.long),
        "labels": torch.tensor(labels),
        "num_messages":num_messages,
        "deltas":torch.tensor(game_score_deltas)
    }


In [11]:
import torch
import torch.nn as nn
%pip install git+https://github.com/geoopt/geoopt.git
import geoopt

class HyperbolicGRUCell(nn.Module):
    def __init__(self,inp_dim,hidden_size,manifold):
        super().__init__()
        self.inp_dim = inp_dim
        self.hid_dim = hidden_size
        self.manifold = manifold

        self.wz = nn.Parameter(nn.init.xavier_normal_(torch.empty(self.hid_dim,self.hid_dim+self.inp_dim,dtype=torch.float64)))
        self.wr = nn.Parameter(nn.init.xavier_normal_(torch.empty(self.hid_dim,self.hid_dim+self.inp_dim,dtype=torch.float64)))
        self.w = nn.Parameter(nn.init.xavier_normal_(torch.empty(self.hid_dim,self.hid_dim+self.inp_dim,dtype=torch.float64)))

    def forward(self,h,x):
        # h and x are of shape bs x hid_dim, bs x inp_dim
        # this is just one pass, not for a sequence
        # print("h shape: ",h.shape)
        # print("x shape: ",x.shape)
        h_x = torch.cat((h,x),dim=-1)
        z = torch.sigmoid(self.manifold.logmap0(self.manifold.projx(self.manifold.mobius_matvec(self.wz,h_x))))
        r = torch.sigmoid(self.manifold.logmap0(self.manifold.projx(self.manifold.mobius_matvec(self.wr,h_x))))
        h_x_to_tilde = self.manifold.projx(self.manifold.expmap0(torch.cat((self.manifold.logmap0(self.manifold.mobius_pointwise_mul(r,h)),self.manifold.logmap0(x)),dim=-1)))
        h_tilde = self.manifold.expmap0(torch.tanh(self.manifold.logmap0(self.manifold.projx(self.manifold.mobius_matvec(self.w,h_x_to_tilde)))))
        h = self.manifold.projx(self.manifold.mobius_add(self.manifold.mobius_pointwise_mul(1-z,h),self.manifold.mobius_pointwise_mul(z,h_tilde)))
        return h



class HyperbolicGRULayer(nn.Module):
    def __init__(self,inp_dim,hidden_size,manifold,dirs):
        super().__init__()
        self.inp_dim = inp_dim
        self.hid_dim = hidden_size
        self.manifold = manifold
        self.dirs = dirs
        self.gru_cell = HyperbolicGRUCell(self.inp_dim,self.hid_dim,self.manifold)
        self.h_init = geoopt.tensor.ManifoldParameter(self.manifold.projx(self.manifold.expmap0(torch.zeros(self.hid_dim,dtype=torch.float64))),manifold=self.manifold)

    def forward(self,seq,lengths):
        # seq is bs x max_seq_len x inp_dim
        # h_init is bs x inp_dim
        # lengths is bs
        max_len = seq.shape[1]
        bs = seq.shape[0]
        hid = self.h_init.expand(bs,-1)
        outs_left = []

        # left to right
        for i in range(max_len):
            inp = seq[:,i,:]
            out = self.gru_cell(hid,inp) # bs x hid_dim
            # but not all outputs will be valid. since some sequences may be padded. so i need to ignore those somehow
            # mask = torch.tensor([1 if len>i else 0 for len in lengths]).unsqueeze(-1).to(seq.device) # error if i dont unsqueeze. inefficient
            mask = (lengths>i).float().unsqueeze(-1).to(seq.device)
            hid = self.manifold.projx(self.manifold.mobius_add(self.manifold.mobius_pointwise_mul(mask,out),self.manifold.mobius_pointwise_mul(1-mask,hid)))
            outs_left.append(hid)

        if(self.dirs==2):
            hid = self.h_init.expand(bs, -1)
            seq = torch.flip(seq,dims=[1])
            outs_right = []
            # right to left
            for i in range(max_len):
                inp = seq[:,i,:]
                out = self.gru_cell(hid,inp) # bs x hid_dim
                # but not all outputs will be valid. since some sequences may be padded. so i need to ignore those somehow
                # mask = torch.tensor([1 if len>i else 0 for len in lengths]).unsqueeze(-1).to(seq.device) # error if i dont unsqueeze. inefficient
                mask = (lengths>i).float().unsqueeze(-1).to(seq.device)
                hid = self.manifold.projx(self.manifold.mobius_add(self.manifold.mobius_pointwise_mul(mask,out),self.manifold.mobius_pointwise_mul(1-mask,hid)))
                outs_right.append(hid)

        outs = torch.stack(outs_left,dim=1)
        if(self.dirs==2):
            outs_right.reverse()
            outs_right = torch.stack(outs_right,dim=1)
            outs = self.manifold.projx(self.manifold.expmap0(torch.cat((self.manifold.logmap0(outs),self.manifold.logmap0(outs_right)),dim=-1)))

        return outs, outs[:,-1,:]

class HyperbolicGRU(nn.Module):
    def __init__(self,inp_dim,hidden_size,manifold,num_layers,dirs):
        super().__init__()
        self.inp_dim = inp_dim
        self.hid_dim = hidden_size
        self.manifold = manifold
        self.num_layers = num_layers
        self.dirs = dirs
        self.layers = nn.ModuleList()
        for i in range(self.num_layers):
            layer_inp_dim = self.inp_dim if i == 0 else (self.hid_dim if self.dirs==1 else 2*self.hid_dim)
            self.layers.append(HyperbolicGRULayer(layer_inp_dim, self.hid_dim, self.manifold, self.dirs))

    def forward(self,seq,lengths):
        inp = seq

        for idx,layer in enumerate(self.layers):
            # print("input to layer ",idx," : inp:",inp.shape," , hid: ",hid.shape)
            inp,hid = layer(inp,lengths)

        return hid



huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Collecting git+https://github.com/geoopt/geoopt.git
  Cloning https://github.com/geoopt/geoopt.git to /tmp/pip-req-build-wygbs_d5
  Running command git clone --filter=blob:none --quiet https://github.com/geoopt/geoopt.git /tmp/pip-req-build-wygbs_d5
  Resolved https://github.com/geoopt/geoopt.git to commit eaadc68fcae361778edf078b503ed79e4497c071
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=2.0.1->geoopt==0.5.1)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch>=2.0.1->geoopt==0.5.1)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch>=2.0.1->geoopt==0.5.1)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64

In [13]:
class HyLinear(nn.Module):
    def __init__(self,input_dim,output_dim,act,manifold,bias=None):
        super(HyLinear, self).__init__()
        self.inp_dim = input_dim
        self.out_dim = output_dim
        self.manifold = manifold
        self.activation = act
        self.weight_matrix = nn.Parameter(torch.randn((self.out_dim,self.inp_dim),dtype = torch.float64))
        if(bias is not None):
            self.bias = geoopt.tensor.ManifoldParameter(self.manifold.projx(self.manifold.expmap0(torch.zeros(self.out_dim,dtype=torch.float64))),manifold=self.manifold)
        else:
            self.bias=None

    def reset_parameters(self):
        torch.nn.init.xavier_uniform_(self.weight_matrix, gain=math.sqrt(2))
        if(self.bias is not None):
            torch.nn.init.constant_(self.bias, 0)


    def forward(self,x):
        # x = x.double()
        op = self.manifold.mobius_matvec(self.weight_matrix,x)
        op = self.manifold.projx(op)
        if(self.bias is not None):
            op = self.manifold.mobius_add(op,self.bias)
            op = self.manifold.projx(op)
        if(self.activation is not None):
            op = self.manifold.projx(self.manifold.expmap0(self.activation(self.manifold.logmap0(op))))
        return op

In [17]:
import random
import torch_geometric
from torch_geometric.data import Batch, Data

class ContextLSTM(nn.Module):
  def __init__(self,embed_model, embedding_dim, hidden_size_message, gat_dim, num_classes,  manifold, num_layers, dirs,context_win_graph=5):
    super(ContextLSTM,self).__init__()
    seed = 100
    self.context_win_graph=context_win_graph
    torch.manual_seed(seed)
    random.seed(seed)
    torch.cuda.manual_seed_all(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


    self.embedding_model = embed_model
    self.embedding_size = embedding_dim

    self.manifold = manifold
    self.num_layers = num_layers
    self.dirs = dirs
    
    for param in self.embedding_model.parameters():
                param.requires_grad = False # hyperparameter

   
    # can add hylinear here
    self.gru_message =  HyperbolicGRU(embedding_dim,hidden_size,self.manifold,self.num_layers,self.dirs)
      
    # self.gat1 = torch_geometric.nn.conv.GATConv(in_channels=hidden_size_message*2,out_channels=gat_dim//2,heads=4)
    self.gat1 = torch_geometric.nn.conv.GATConv(in_channels=hidden_size_message*2,out_channels=gat_dim*2,heads=4,  concat=False).to(torch.float64)
      
    self.gat2 = torch_geometric.nn.conv.GATConv(in_channels=gat_dim*2,out_channels=gat_dim,heads=2 , concat=False).to(torch.float64)
    self.relu = torch.nn.ReLU()
    self.tanh = torch.nn.Tanh()
    self.sigmoid = torch.nn.Sigmoid()  
    self.gate1 = nn.Linear(hidden_size_message * 2 + gat_dim,hidden_size_message * 2,dtype=torch.float64)
    self.gate2 = nn.Linear(hidden_size_message * 2 + gat_dim,gat_dim,dtype=torch.float64)
    self.hidden_size_message = hidden_size_message
    self.gat_dim = gat_dim


    self.attn_dim = 256  # Should match your hidden size requirements
    self.attn = nn.MultiheadAttention(
        embed_dim=self.attn_dim, 
        num_heads=8,
        kdim=self.hidden_size_message*2 + self.gat_dim,
        vdim=self.hidden_size_message*2 + self.gat_dim,
        dtype=torch.float64
    )
    self.fusion_proj = nn.Linear(self.hidden_size_message*2 + self.gat_dim, self.attn_dim,dtype=torch.float64)

    self.fc = nn.Linear(hidden_size_message * 2 + gat_dim+1, num_classes).to(dtype=torch.float64)

    # self.fc = nn.Linear(self.attn_dim, num_classes,dtype=torch.float64)

  def attn_fusion(self, x):
      # Project combined features
    projected = self.fusion_proj(x)  # [batch, attn_dim]
    
    # Attention expects [seq_len, batch, features]
    projected = projected.unsqueeze(0)  # [1, batch, attn_dim]
    
    # Self-attention with learned relationships
    attn_output, _ = self.attn(
        projected,  # Query
        x.unsqueeze(0),  # Key (original features)
        x.unsqueeze(0)   # Value (original features)
    )
    
    return attn_output.squeeze(0)
      
  def fusion(self,x):
      lstm_emb = x[:,:2*self.hidden_size_message]
      gat_emb = x[:,2*self.hidden_size_message:]
      
      lstm_emb_tanh = self.tanh(lstm_emb) 
      gat_emb_tanh = self.tanh(gat_emb)
      # print('lstm emb: ',lstm_emb.shape)
      # print("gat emb shape: ",gat_emb.shape)
      mixed = torch.cat((lstm_emb,gat_emb),dim=1)
      mixed_sigmoid_left = self.sigmoid(self.gate1(mixed))
      mixed_sigmoid_right = self.sigmoid(self.gate2(mixed))

      final_emb = torch.cat((mixed_sigmoid_left*lstm_emb_tanh,mixed_sigmoid_right*gat_emb_tanh),dim=1)
      return final_emb

  def get_message_emb(self,input_ids,lengths):
    embedded = self.manifold.expmap0(self.embedding_model(input_ids).last_hidden_state)
    last_hidden = self.gru_message(embedded,lengths)
    return last_hidden

  

  def forward(self, input_ids, num_messages, lengths , scores):
    total_messages = sum(num_messages)
    # print(input_ids.shape)
    inputs_to_msg_encoder = input_ids
    encoded_messages = self.manifold.logmap0(self.get_message_emb(inputs_to_msg_encoder,lengths))
    input_to_convo_encoder = torch.split(encoded_messages,num_messages)
    # print(input_to_convo_encoder[0].shape)

    tot_data = []

    # for i in range(len(input_to_convo_encoder)):
    #     msg_embs = input_to_convo_encoder[i] # num_msgs x emb_dim
    #     n_nodes = msg_embs.shape[0]
    #     edge_ind = torch.combinations(torch.arange(n_nodes), r=2).T
    #     edge_ind = torch.cat([edge_ind, edge_ind.flip(0)], dim=1).to(input_ids.device)
    
    #     tot_data.append(Data(x=msg_embs,edge_index=edge_ind))


    threshold = 0.7  # Can tune this
    
    for i, msg_embs in enumerate(input_to_convo_encoder):
        n_nodes = msg_embs.shape[0]
        sources, targets = [], []
    
        norm_embs = torch.nn.functional.normalize(msg_embs, p=2, dim=1)
    
        sim_matrix = torch.mm(norm_embs, norm_embs.T)  
    
        for src in range(n_nodes):
            for tgt in range(n_nodes):
                if src > tgt and sim_matrix[src, tgt] > threshold:
                    # print(src,tgt)
                    sources.append(src)
                    targets.append(tgt)
    
    
        edge_index = torch.tensor([sources, targets], device=input_ids.device, dtype=torch.long)
    
        tot_data.append(Data(x=msg_embs, edge_index=edge_index))

    tot_data = Batch.from_data_list(tot_data)
    # print("tot_data dtype: ",tot_data.x.dtype)
    from_gat1 = self.relu(self.gat1(tot_data.x,tot_data.edge_index))
    
    # print("from gat 1 shape: ",from_gat1.shape)

    from_gat2 = self.gat2(from_gat1,tot_data.edge_index)
    # print("enc msg shape: ",encoded_messages.shape)
    # print("from_gat2 shap: ",from_gat2.shape)
      
    input_for_fusion = [] 
    prev = 0
    # print("enc msg shape: ",encoded_messages.shape)
    for i, msg_emb in enumerate(encoded_messages):
        # print("i: ",i)
        inp = torch.cat((msg_emb,from_gat2[i]),dim=0)
        input_for_fusion.append(inp)
    # print("shape before fusion: ",torch.stack(input_for_fusion).shape)
    # final_combined = self.fusion(torch.stack(input_for_fusion))
    # print("final combined shape: ",final_combined.shape)
    final_input = torch.cat((torch.stack(input_for_fusion),  scores.unsqueeze(1)), dim=1)
    # print(final_input.shape)
    # logits = self.fc(torch.stack(input_for_fusion))
    logits = self.fc(((final_input)))
      
    # print("logits shape: ",logits.shape)
    return logits



In [18]:
import gc
gc.collect()
torch.cuda.empty_cache()
torch.cuda.ipc_collect()

In [None]:
from transformers import BertModel
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


embedding_dim = 768
hidden_size = 128
num_classes = 2
gat_dim = 50

seed = 100

torch.manual_seed(seed)
random.seed(seed)
torch.cuda.manual_seed_all(seed)
random.seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


manifold = geoopt.manifolds.PoincareBall(c=1.0)

model_embed = BertModel.from_pretrained('bert-base-uncased').to(device)

model = ContextLSTM(model_embed,  embedding_dim, hidden_size,gat_dim, num_classes,manifold , 1,2,5)
model = model.to(device)

class_weights = torch.tensor([1.0 / 0.10, 1.0 / 0.90], dtype=torch.float64)
class_weights = class_weights.to(device)


# model = ContextLSTM(vocab, glove_file, embedding_dim, hidden_size//2,hidden_size, num_classes)
# model = model.to(device)
optimizer =  geoopt.optim.RiemannianAdam(model.parameters(),lr=1e-3)
loss = nn.CrossEntropyLoss(weight=class_weights)

train_dataset = Deception_dataset_context(train)
val_dataset = Deception_dataset_context(val)
train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True, collate_fn=collate_fn_context)
val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn_context)


test_dataset = Deception_dataset_context(test)
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn_context)

print("hey")
for epoch in range(15):
    
    model.train()
    train_loss = 0
    train_preds, train_labels = [], []
    for batch in tqdm(train_dataloader):
        messages = batch["messages"].to(device)
        lengths = batch["lengths"]
        labels = batch["labels"].to(device)
        num_messages = batch['num_messages']
        scores = batch['deltas'].to(device)

        optimizer.zero_grad()
        logits = model(messages, num_messages,lengths,scores)
        loss_ = loss(logits, labels)
        loss_.backward()
        optimizer.step()

        train_loss += loss_.item()
        preds = torch.argmax(logits, dim=1).cpu().numpy()
        train_preds.extend(preds)
        train_labels.extend(labels.cpu().numpy())


    train_loss /= len(train_dataloader)
    train_f1 = f1_score(train_labels, train_preds, average='macro')


    model.eval()
    val_loss = 0
    val_preds, val_labels = [], []
    with torch.no_grad():
        for batch in val_dataloader:
            messages = batch["messages"].to(device)
            lengths = batch["lengths"]
            labels = batch["labels"].to(device)
            num_messages = batch['num_messages']
            scores = batch['deltas'].to(device)
            logits = model(messages, num_messages,lengths,scores)
            loss_ = loss(logits, labels)
            val_loss += loss_.item()
            preds = torch.argmax(logits, dim=1).cpu().numpy()
            val_preds.extend(preds)
            val_labels.extend(labels.cpu().numpy())

    val_loss /= len(val_dataloader)
    val_f1 = f1_score(val_labels, val_preds, average='macro')

    model.eval()
    test_preds, test_labels = [], []
    with torch.no_grad():
        for batch in test_dataloader:
            messages = batch["messages"].to(device)
            lengths = batch["lengths"]
            labels = batch["labels"].to(device)
            num_messages = batch['num_messages']
            scores = batch['deltas'].to(device)
            logits = model(messages, num_messages,lengths,scores)
    
            preds = torch.argmax(logits, dim=1).cpu().numpy()
            test_preds.extend(preds)
            test_labels.extend(labels.cpu().numpy())
    
    
    test_f1 = f1_score(test_labels, test_preds, average='macro')

    torch.save(model.state_dict(),f"epoch_{epoch+1}.pth")
    print(test_f1)


    print(f"Epoch {epoch+1}:  Train Loss: {train_loss:.4f}, Train F1: {train_f1:.4f} Val Loss: {val_loss:.4f}, Val F1: {val_f1:.4f}")

hey


100%|██████████| 199/199 [19:28<00:00,  5.87s/it]


0.48692930888245445
Epoch 1:  Train Loss: 0.5925, Train F1: 0.4899 Val Loss: 0.5789, Val F1: 0.4888


  1%|          | 2/199 [00:10<17:16,  5.26s/it]

In [None]:
 # torch.save(model.state_dict(),f"gat+lstm+glove_best.pth")

In [None]:
# test_dataset = Deception_dataset_context(test, vocab)
# test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn_context)


# model.eval()
# test_preds, test_labels = [], []
# with torch.no_grad():
#     for batch in test_dataloader:
#         messages = batch["messages"].to(device)
#         lengths = batch["lengths"]
#         labels = batch["labels"].to(device)
#         num_messages = batch['num_messages']
#         scores = batch['deltas'].to(device)
#         logits = model(messages, num_messages,lengths,scores)

#         preds = torch.argmax(logits, dim=1).cpu().numpy()
#         test_preds.extend(preds)
#         test_labels.extend(labels.cpu().numpy())


# test_f1 = f1_score(test_labels, test_preds, average='macro')

# print(test_f1)