# Inference Demo: BiLSTM Classifier (Solution B)

In [None]:
# Inference Demo Code - Solution B - BILSTM

# Please install the following libraries before running the code
# !pip install torch
# !pip install transformers
# !pip install pandas

# Mount Google Drive so the checkpoint, hyper-parameter and test CSV paths
# under /content/drive used below resolve (Colab only).
from google.colab import drive
drive.mount(mountpoint='/content/drive')

import torch
import torch.nn as nn
from transformers import BertTokenizer, AutoModel
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import re
import json
import os

MAX_LEN = 128  # token length every example is padded/truncated to in EDDataset
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# model classification
class BERT_LSTMClassifier(nn.Module):
  """Bidirectional-LSTM classifier over BERT's word-piece embedding table.

  Only ``embeddings.word_embeddings`` of ``pretrained_model`` is kept. Token ids
  are embedded, run through a stacked bidirectional LSTM, mean-pooled over the
  sequence dimension, passed through dropout and projected to ``num_classes``
  logits.

  Args:
    hidden_size: LSTM hidden units per direction (the pooled vector is 2x this).
    num_classes: number of output logits.
    dropout: dropout probability applied to the pooled vector before ``fc``.
    num_layers: number of stacked LSTM layers.
    pretrained_model: Hugging Face model id whose word embeddings are reused.
    masked_mean: if True, positions where ``attention_mask == 0`` are excluded
      from the mean pooling. The default (False) averages over every position,
      padding included, which is the original behaviour. Existing checkpoints
      were presumably trained that way, so only enable it for models trained
      with masked pooling.
  """

  def __init__(self, hidden_size=128, num_classes=2, dropout=0.3, num_layers=2,
               pretrained_model='bert-base-uncased', masked_mean=False):
    super(BERT_LSTMClassifier, self).__init__()
    # Loads the full pretrained model but keeps only its word-embedding table.
    self.bert_embeddings = AutoModel.from_pretrained(pretrained_model).embeddings.word_embeddings
    self.masked_mean = masked_mean

    self.lstm = nn.LSTM(
        # Derived from the embedding table rather than hard-coded to 768, so
        # `pretrained_model` values with a different hidden size also work.
        input_size=self.bert_embeddings.embedding_dim,
        hidden_size=hidden_size,
        num_layers=num_layers,
        batch_first=True,
        bidirectional=True
    )
    self.dropout = nn.Dropout(dropout)
    # Bidirectional LSTM -> forward and backward states are concatenated.
    self.fc = nn.Linear(hidden_size * 2, num_classes)

  def forward(self, input_ids, attention_mask):
    """Return class logits, one row of ``num_classes`` scores per example.

    Args:
      input_ids: token-id tensor with the batch dimension first (the LSTM is
        ``batch_first``); presumably (batch, seq_len) as built by EDDataset.
      attention_mask: tensor shaped like ``input_ids`` (HF convention: 1 for
        real tokens, 0 for padding). Only used when ``self.masked_mean`` is True.
    """
    embeddings = self.bert_embeddings(input_ids)
    lstm_out, _ = self.lstm(embeddings)
    if self.masked_mean:
      mask = attention_mask.unsqueeze(-1).to(lstm_out.dtype)
      # clamp avoids division by zero for an all-padding row.
      token_counts = mask.sum(dim=1).clamp(min=1.0)
      pooled_output = (lstm_out * mask).sum(dim=1) / token_counts
    else:
      # Original behaviour: average over every position, padding included.
      pooled_output = torch.mean(lstm_out, dim=1)
    out = self.dropout(pooled_output)
    logits = self.fc(out)
    return logits

# Word-piece tokenizer read as a module-level global by EDDataset; it uses the
# same checkpoint name as BERT_LSTMClassifier's default `pretrained_model`.
tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path='bert-base-uncased')

# ED dataset
class EDDataset(Dataset):
  """Map-style dataset that tokenizes each text lazily in ``__getitem__``.

  Args:
    texts: iterable of input strings. It is copied into a list so the integer
      positions DataLoader passes to ``__getitem__`` are always positional.
      Indexing a pandas Series that has a non-default integer index would
      otherwise do a label lookup (KeyError or the wrong row).
    labels: optional iterable of integer class labels aligned with ``texts``.
      When given, each item also carries a ``'labels'`` LongTensor.
    tokenizer: Hugging Face-style tokenizer callable. Defaults to the
      module-level ``tokenizer``.
    max_len: length every example is padded/truncated to. Defaults to
      ``MAX_LEN``.

  Raises:
    ValueError: if ``labels`` is given and its length differs from ``texts``.
  """

  def __init__(self, texts, labels=None, tokenizer=None, max_len=None):
    self.texts = list(texts)
    self.labels = None if labels is None else list(labels)
    if self.labels is not None and len(self.labels) != len(self.texts):
      raise ValueError(
          f"got {len(self.texts)} texts but {len(self.labels)} labels")
    # The `tokenizer` argument shadows the module-level tokenizer, so the
    # default has to be looked up explicitly through globals().
    self.tokenizer = tokenizer if tokenizer is not None else globals()['tokenizer']
    self.max_len = MAX_LEN if max_len is None else max_len

  def __len__(self):
    return len(self.texts)

  def __getitem__(self, index):
    text = self.texts[index]
    encoding = self.tokenizer(text, max_length=self.max_len, padding='max_length',
                              truncation=True, return_tensors='pt')
    # return_tensors='pt' adds a leading batch dimension of size 1; drop it so
    # DataLoader can stack items into a batch.
    item = {
        'input_ids': encoding['input_ids'].squeeze(0),
        'attention_mask': encoding['attention_mask'].squeeze(0)
    }
    if self.labels is not None:
      item['labels'] = torch.tensor(self.labels[index], dtype=torch.long)
    return item

# text preprocessing
def clean_text(text):
  """Normalise a raw claim/evidence string before tokenization.

  Expands common English contractions, replaces URLs with the literal token
  ``[URL]`` and collapses runs of whitespace into single spaces (trimmed).

  NOTE(review): every contraction pattern except the ``'s`` rule is
  case-sensitive, so e.g. "Can't" becomes "Ca not" and "Won't" becomes
  "Wo not". This is intentionally left unchanged: inference preprocessing has
  to match whatever was used at training time. Confirm against the training
  notebook before changing it.
  """
  # (pattern, replacement, flags), applied strictly in this order. The specific
  # "can't"/"won't" forms must run before the generic "n't" rule.
  substitutions = (
      (r"can't\b", "cannot", 0),
      (r"won't\b", "will not", 0),
      (r"n't\b", " not", 0),
      (r"'re\b", " are", 0),
      (r"'m\b", " am", 0),
      (r"'ve\b", " have", 0),
      (r"'ll\b", " will", 0),
      (r"'d\b", " would", 0),
      (r"\b(he|she|it|that|what|who|there|where|why|when)'s\b", r"\1 is", re.IGNORECASE),
      (r'http\S+|www\S+|https\S+', '[URL]', 0),
  )
  for pattern, replacement, flags in substitutions:
    text = re.sub(pattern, replacement, text, flags=flags)
  return re.sub(r"\s+", " ", text).strip()

Mounted at /content/drive


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

# Prediction Helpers

In [None]:
# predict function
def predict(model, dataloader, device):
  """Run ``model`` over every batch of ``dataloader`` and stack the logits.

  Puts the model in eval mode and disables gradient tracking. Each batch must
  be a mapping with ``'input_ids'`` and ``'attention_mask'`` tensors. These are
  moved to ``device`` before the forward pass.

  Returns:
    The per-batch model outputs concatenated along dim 0, in dataloader order
    and left on ``device``. Raises (from ``torch.cat``) if the dataloader
    yields no batches.
  """
  model.eval()
  # No gradients are needed for inference; this saves time and memory.
  with torch.no_grad():
    batch_logits = [
        model(batch['input_ids'].to(device), batch['attention_mask'].to(device))
        for batch in dataloader
    ]
  return torch.cat(batch_logits, dim=0)

# load best model function
def load_best_model(model_path, param_path, device):
  """Rebuild the BiLSTM classifier from saved hyper-parameters and weights.

  Args:
    model_path: path to a state dict readable by ``torch.load`` (it is passed
      straight to ``load_state_dict``).
    param_path: path to a JSON file holding at least ``lstm_hidden_size``,
      ``dropout`` and ``num_lstm_layers``.
    device: torch device the model is moved to and the weights are mapped onto.

  Returns:
    A two-class ``BERT_LSTMClassifier`` on ``device``, in eval mode.
  """
  with open(param_path) as param_file:
    hparams = json.load(param_file)

  model = BERT_LSTMClassifier(
      hidden_size=hparams['lstm_hidden_size'],
      dropout=hparams['dropout'],
      num_layers=hparams['num_lstm_layers'],
      num_classes=2,
  )
  model = model.to(device)
  # NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
  weights = torch.load(model_path, map_location=device)
  model.load_state_dict(weights)
  model.eval()
  return model


# Run Prediction

In [None]:
# main: load the tuned BiLSTM checkpoint, score the test set and save predictions.
if __name__ == '__main__':
  model_dir = '/content/drive/My Drive/nlu-lab/lstm_models'
  model_path = os.path.join(model_dir, 'best_model_bilstm.pt')
  best_param_path = os.path.join(model_dir, 'best_params_bilstm.json')
  test_file_path = '/content/drive/My Drive/nlu-lab/test.csv'
  # Same path as before, derived from model_dir instead of repeating it.
  output_file = os.path.join(model_dir, 'Group_21_B.csv')

  # Best hyper-parameters: batch_size is needed for the DataLoader below.
  with open(best_param_path, 'r') as f:
    params = json.load(f)

  # Rebuild the model via the shared helper instead of duplicating its logic.
  model = load_best_model(model_path, best_param_path, device)

  # test dataset
  test_df = pd.read_csv(test_file_path)
  # fillna('') stops a missing Claim/Evidence cell from turning the whole
  # concatenation into NaN, which clean_text (re.sub) would reject with TypeError.
  test_df['combined_text'] = (
      test_df['Claim'].fillna('') + " " + test_df['Evidence'].fillna('')
  ).apply(clean_text)
  test_dataset = EDDataset(test_df['combined_text'].tolist())
  # shuffle=False keeps predictions aligned with the rows of test_df.
  test_dataloader = DataLoader(test_dataset, batch_size=params['batch_size'], shuffle=False)

  # test prediction
  test_logits = predict(model, test_dataloader, device)
  test_pred = torch.argmax(test_logits, dim=1)

  # save results
  results = pd.DataFrame({'prediction': test_pred.cpu().numpy()})
  assert len(results) == len(test_df), "prediction count does not match test rows"
  results.to_csv(output_file, index=False)
  print(f"Results saved to: {output_file}")
  print(results.head())


Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

Results saved to: /content/drive/My Drive/nlu-lab/lstm_models/Group_21_B.csv
   prediction
0           1
1           1
2           1
3           0
4           0
