# Importing modules

In [1]:


import numpy as np
import math


import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.rnn as rnn_utils
from torch.utils.data import DataLoader, TensorDataset, Subset

import csv

import multiprocessing as mp
import os
import sys
import copy
import random
import gc
import time
from tqdm import tqdm
from collections import defaultdict

import itertools

import dill

import warnings
warnings.filterwarnings('ignore')

from datasets import load_dataset

In [2]:
from model import *

# Checking cuda

In [3]:
if torch.cuda.is_available():
    for i in range(torch.cuda.device_count()):
        print(f"Device {i}: {torch.cuda.get_device_name(i)}")
    device_index = 0
    device = torch.device(f"cuda:{device_index}")
    print('using cuda...')
else:
    device = torch.device("cpu")
    print('using cpu...')

torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True

Device 0: NVIDIA GeForce RTX 4090
using cuda...


In [4]:
# Load the IMDB dataset
dataset = load_dataset('squad')


In [5]:
questions = dataset['train']['question'][:1000]
answers   = dataset['train']['answers'][:1000]
qa_pairs  = []
for question, answer in zip(questions, answers):
    qa_pairs.append(f"Q: {question}\nA: {answer['text'][0]}\n")

# Show the first 3 question-answer pairs
for qa in qa_pairs[:3]:
    print(qa)
    print("-" * 50)

Q: To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?
A: Saint Bernadette Soubirous

--------------------------------------------------
Q: What is in front of the Notre Dame Main Building?
A: a copper statue of Christ

--------------------------------------------------
Q: The Basilica of the Sacred heart at Notre Dame is beside to which structure?
A: the Main Building

--------------------------------------------------


In [6]:
import torch
from transformers import BertTokenizer, BertModel
import numpy as np

max_length = 100

# 初始化 BERT tokenizer 和 vectorizer
tokenizer  = BertTokenizer.from_pretrained('bert-base-uncased')
vectorizer = BertModel.from_pretrained('bert-base-uncased')

# 範例句子
# sentences = [
#     "Hello, how are you? I am fine",
#     "BERT is a powerful model.  I am fine",
#     "Let's use transformers for NLP.  I am fine"
#     "You are really a muta flicker.  I am fine"
# ]
sentences = qa_pairs


# Step 1: Tokenize the sentences
tokenized_sentences = tokenizer(sentences, padding='max_length', max_length=max_length, truncation=True, return_tensors="pt")


# Step 2: Vectorize the sentences
input_ids           = tokenized_sentences['input_ids']
attention_masks     = tokenized_sentences['attention_mask']
with torch.no_grad(): 
    input_vectors   = vectorizer(input_ids).last_hidden_state * ( attention_masks.unsqueeze(2) )




def create_dataset(input_vectors, attention_masks):

    final_input  = []
    final_label  = []
    final_mask_1 = []
    final_mask_2 = []

    for i in range(input_vectors.size(0)): 
        for j in range(input_vectors.size(1) - 1):
            
            if attention_masks[i][j + 1] != 0:

                factored_mask     = torch.zeros_like(attention_masks[i])
                factored_mask[:j] = 1
                input             = input_vectors[i] * factored_mask.unsqueeze(1)
                final_input.append(input)

                label = input_vectors[i][j+1]
                final_label.append(label)

                mask_2 = factored_mask.unsqueeze(1) * factored_mask.unsqueeze(0)
                mask_2 = mask_2.unsqueeze(0)
                mask_1 = (mask_2 -1) * 1e20
                final_mask_1.append(mask_1)
                final_mask_2.append(mask_2)

    final_input  = torch.stack(final_input , dim=0)
    final_label  = torch.stack(final_label , dim=0)
    final_mask_1 = torch.stack(final_mask_1, dim=0)
    final_mask_2 = torch.stack(final_mask_2, dim=0)

    return final_input, final_label, final_mask_1, final_mask_2





# 生成資料集
final_input, final_label, final_mask_1, final_mask_2 = create_dataset(input_vectors, attention_masks)






We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.


In [7]:



dataset    = TensorDataset(final_input, final_label, final_mask_1, final_mask_2)
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)


In [8]:

sequence_size =  max_length             
feature_size = 768          
num_layers = 3                      
num_heads = 4                
hidden_activation = 'tanh'
output_activation = 'sigmoid'
initializer = "random_normal"
optimizer = 'sgd'
loss = 'mean_squared_error'
bias = False
drop_rate = 0.0
alpha = 0.000001                        


# Initialize the model
model = build_model(sequence_size,
                    feature_size,
                    num_layers,
                    num_heads,
                    hidden_activation,
                    output_activation,
                    initializer,
                    optimizer,
                    loss,
                    bias,
                    drop_rate,
                    alpha)


model = model.to(device)


In [18]:
num_epochs = 100  

for epoch in range(num_epochs):
    
    model.train()  
    
    running_loss = 0.0

    for batch_idx, (input, label, mask_1, mask_2) in tqdm(enumerate(dataloader), total=len(dataloader), desc="Training Progress", ncols=100, unit="batch"):

        input  = input.to(device)
        label  = label.to(device)
        mask_1 = mask_1.to(device)
        mask_2 = mask_2.to(device)
        
        optimizer = model.optimizer
        optimizer.zero_grad()

        loss_function   = model.loss_function
        output          = model(input, (mask_1, mask_2))
        loss            = loss_function(output, label)
        loss.backward()     # get grad

        optimizer.step()    # update params

        running_loss += loss.item()
        

    epoch_loss = running_loss / len(dataloader)
    print(f"Epoch [{epoch+1}/{num_epochs}] finished. Average Loss: {epoch_loss:.4f}")


    


Training Progress:  44%|███████████████▊                    | 9003/20522 [01:59<02:36, 73.81batch/s]

In [19]:
import torch
from transformers import BertTokenizer, BertModel
import numpy as np

# 範例句子
sentence = "who are you"

for _ in range(max_length):

    # Step 1: Tokenize the sentences
    tokenized_sentence = tokenizer(sentence, padding='max_length', max_length=max_length, truncation=True, return_tensors="pt")

    # Step 2: Vectorize the sentences
    input_id           = tokenized_sentence['input_ids']
    attention_mask     = tokenized_sentence['attention_mask']
    with torch.no_grad(): 
        input_vector   = vectorizer(input_id).last_hidden_state 

    input_vector = input_vector.to(device)
    mask_2 = attention_mask[0].unsqueeze(1) * attention_mask[0].unsqueeze(0)
    mask_2 = mask_2.unsqueeze(0).unsqueeze(0)
    mask_2 = mask_2.to(device)
    mask_1 = (mask_2 -1) * 1e20
    mask_1 = mask_1.to(device)

    model.eval()  
    output          = model(input_vector, (mask_1, mask_2))

    vocab_embeddings       = vectorizer.get_input_embeddings().weight .to(device)
    cos_sim                = F.cosine_similarity(output, vocab_embeddings, dim=1)
    most_similar_token_idx = torch.argmax(cos_sim).item()

    word                   = tokenizer.convert_ids_to_tokens(most_similar_token_idx)
    print(word)

    
    
    # vocab = tokenizer.get_vocab()
    # while most_similar_token_idx in vocab and vocab[most_similar_token_idx].startswith('[unused'):
    #     most_similar_token_idx = torch.argmax(cos_sim).item()  # Recompute the most similar token

    # from transformers import GPT2LMHeadModel, GPT2Tokenizer
    # tokenizer_ = GPT2Tokenizer.from_pretrained('gpt2')
    # model_     = GPT2LMHeadModel.from_pretrained('gpt2')
    # output     = model_.generate(torch.tensor([[most_similar_token_idx]]), max_length=20)
    # word       = tokenizer_.decode(output[0], skip_special_tokens=True)
    # print(word)

    sentence += '' + word


prima
needs
meter
meter
desk
point
point
progress
even
meters
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
