In [1]:
import json
import torch
import torch.nn as nn
import os
from tqdm.notebook import tqdm
from collections import Counter

from torch.utils.data import Dataset, DataLoader, ConcatDataset
from transformers import BertTokenizer, BertConfig, BertModel

In [2]:
seed = 0
max_len = 256

In [3]:
# train_json = json.load(open('data/train_set.json'))

In [4]:
# list(train_json.values())[-1]

In [5]:
# Custom Dataset class for loading PubMedQA data from JSON files
class PubMedQADataset(Dataset):
    def __init__(self, json_file, tokenizer, max_length=512, labeled=True):
        self.data = list(json.load(open(json_file)).values())
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.labeled = labeled

        self.label_map = {
            "yes": 0,
            "no": 1,
            "maybe": 2,
            "no_label":3
        }

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        question = self.data[idx]['QUESTION']
        context = self.data[idx]['CONTEXTS']
        inputs = self.tokenizer(question, context, padding='max_length', truncation=True, max_length=self.max_length, return_tensors='pt')
        
        item = {
            'input_ids': inputs['input_ids'].squeeze(),
            'attention_mask': inputs['attention_mask'].squeeze()
        }

        if self.labeled:
            label = self.data[idx]['final_decision']
            item['labels'] = torch.tensor(self.label_map[label])
        else:
            item['labels'] = torch.tensor(self.label_map["no_label"])

        return item



In [6]:

# Load datasets
tokenizer = BertTokenizer.from_pretrained('nlpie/tiny-biobert')
labeled_dataset = PubMedQADataset('data/train_set.json', tokenizer,max_length=max_len, labeled=True)
artificial_dataset = PubMedQADataset('data/ori_pqaa.json',tokenizer,max_length=max_len,labeled=True)
unlabeled_dataset = PubMedQADataset('data/ori_pqau.json', tokenizer, max_length=max_len,labeled=False)

# Create DataLoaders
labeled_loader = DataLoader(labeled_dataset, batch_size=8, shuffle=True)
artificial_loader =  DataLoader(artificial_dataset, batch_size=8, shuffle=True)
unlabeled_loader = DataLoader(unlabeled_dataset, batch_size=8, shuffle=True)


In [7]:

class TransformerGenerator(nn.Module):
    def __init__(self, model_name='bert-base-uncased'):
        super(TransformerGenerator, self).__init__()
        self.config = BertConfig.from_pretrained(model_name)
        self.bert = BertModel.from_pretrained(model_name, config=self.config)
        self.fc = nn.Linear(self.config.hidden_size, self.config.vocab_size)
    
    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        hidden_states = outputs.last_hidden_state
        logits = self.fc(hidden_states)
        return logits


torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

generator = TransformerGenerator('nlpie/tiny-biobert').to('cuda')


Some weights of BertModel were not initialized from the model checkpoint at nlpie/tiny-biobert and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [8]:
from transformers import BertForSequenceClassification


torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

discriminator = BertForSequenceClassification.from_pretrained('nlpie/tiny-biobert', num_labels=4).to('cuda')


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at nlpie/tiny-biobert and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [9]:

# Concatenate labeled and unlabeled datasets for joint training
combined_dataset = ConcatDataset([labeled_dataset, artificial_dataset,unlabeled_dataset])
combined_loader = DataLoader(combined_dataset, batch_size=8, shuffle=True)

In [10]:
def compute_class_counts(dataset):
    class_counts = Counter()
    for data in tqdm(dataset):
        label = data['labels'].item()
        class_counts[label] += 1
    return class_counts

# Compute class counts and weights
class_counts = compute_class_counts(combined_dataset)


  0%|          | 0/273018 [00:00<?, ?it/s]

In [11]:
class_count_list = [class_counts[i] for i in range(len(class_counts))]
class_weights = [max(class_count_list) / count for count in class_count_list]
class_weights = torch.tensor(class_weights, dtype=torch.float).to('cuda')

# Create the criterion with class weights
criterion = nn.CrossEntropyLoss(weight=class_weights)

In [12]:
class_counts,class_count_list

(Counter({0: 196420, 3: 61249, 1: 15294, 2: 55}), [196420, 15294, 55, 61249])

In [13]:
class_weights

tensor([1.0000e+00, 1.2843e+01, 3.5713e+03, 3.2069e+00], device='cuda:0')

In [14]:
from torch.optim import AdamW


# Optimizers
g_optimizer = AdamW(generator.parameters(), lr=1e-4)
d_optimizer = AdamW(discriminator.parameters(), lr=1e-5)

In [20]:
g_start_epoch

3

In [21]:


# # Loss function
# criterion = nn.CrossEntropyLoss()

# Training loop
num_epochs = 1

for epoch in range(g_start_epoch,g_start_epoch+num_epochs):
    discriminator.train()
    generator.train()
    progress_bar = tqdm(combined_loader, desc=f"Epoch {epoch + 1}/{num_epochs}")

    for batch in progress_bar:
        real_data = batch['input_ids'].to('cuda')
        attention_mask = batch['attention_mask'].to('cuda')
        
        # If the batch contains labeled data, use labels; otherwise, generate fake labels
        if 'labels' in batch:
            labels = batch['labels'].to('cuda')
        else:
            labels = torch.zeros(real_data.size(0), dtype=torch.long).to('cuda')

        # Discriminator forward pass
        d_optimizer.zero_grad()
        outputs = discriminator(input_ids=real_data, attention_mask=attention_mask).logits
        d_loss_real = criterion(outputs, labels)

        # Generator forward pass
        noise = torch.randint(0, generator.config.vocab_size, (real_data.size(0), real_data.size(1))).to('cuda')
        fake_data_logits = generator(input_ids=noise, attention_mask=(noise != tokenizer.pad_token_id).to('cuda'))
        fake_data = torch.argmax(fake_data_logits, dim=-1)
        fake_labels = torch.zeros(real_data.size(0), dtype=torch.long).to('cuda')  # Fake labels are zeros

        d_loss_fake = criterion(discriminator(input_ids=fake_data, attention_mask=(fake_data != tokenizer.pad_token_id).to('cuda')).logits, fake_labels)
        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        d_optimizer.step()

        # Generator training
        g_optimizer.zero_grad()
        noise = torch.randint(0, generator.config.vocab_size, (real_data.size(0), real_data.size(1))).to('cuda')
        fake_data_logits = generator(input_ids=noise, attention_mask=(noise != tokenizer.pad_token_id).to('cuda'))
        fake_data = torch.argmax(fake_data_logits, dim=-1)
        g_loss = criterion(discriminator(input_ids=fake_data, attention_mask=(fake_data != tokenizer.pad_token_id).to('cuda')).logits, fake_labels)
        g_loss.backward()
        g_optimizer.step() 

    print(f"Epoch {epoch + 1}/{num_epochs} | Discriminator Loss: {d_loss.item()} | Generator Loss: {g_loss.item()}")


Epoch 4/1:   0%|          | 0/34128 [00:00<?, ?it/s]

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



Epoch 4/1 | Discriminator Loss: 0.38557541370391846 | Generator Loss: 0.0003756763762794435


In [16]:
outputs[:,:-1].shape

0.82, 0.00

torch.Size([20, 3])

In [22]:
# Load test dataset
test_dataset = PubMedQADataset('data/test_set.json', tokenizer, labeled=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# Evaluation
discriminator.eval()
correct = 0
total = 0

with torch.no_grad():
    for batch in test_loader:
        inputs = batch['input_ids'].to('cuda')
        labels = batch['labels'].to('cuda')
        outputs = discriminator(input_ids=inputs, attention_mask=batch['attention_mask'].to('cuda')).logits
        _, predicted = torch.max(outputs[:,:-1], 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = correct / total
print(f"Test Accuracy: {accuracy * 100:.2f}%")


Test Accuracy: 40.60%


In [23]:


# Evaluation
discriminator.eval()
correct = 0
total = 0

with torch.no_grad():
    for batch in labeled_loader:
        inputs = batch['input_ids'].to('cuda')
        labels = batch['labels'].to('cuda')
        outputs = discriminator(input_ids=inputs, attention_mask=batch['attention_mask'].to('cuda')).logits
        _, predicted = torch.max(outputs[:,:-1], 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = correct / total
print(f"Labelled Train Accuracy: {accuracy * 100:.2f}%")


Labelled Train Accuracy: 39.80%


In [24]:
discriminator.eval()
correct = 0
total = 0

artificial_loader =  DataLoader(artificial_dataset, batch_size=512, shuffle=True)


with torch.no_grad():
    for batch in tqdm(artificial_loader):
        inputs = batch['input_ids'].to('cuda')
        labels = batch['labels'].to('cuda')
        outputs = discriminator(input_ids=inputs, attention_mask=batch['attention_mask'].to('cuda')).logits
        _, predicted = torch.max(outputs[:,:-1], 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = correct / total
print(f"Artificial Train Accuracy: {accuracy * 100:.2f}%")

  0%|          | 0/413 [00:00<?, ?it/s]

Artificial Train Accuracy: 92.20%


In [None]:

                                    Epoch
Artificial Train Accuracy: 59.19%,87.26%,89.23%,92.20
Labelled Train Accuracy: 31.80%,37%,41.60%,39.80
Test Accuracy: 30.60%,35.., 44.40%,40.60

In [16]:
save_dir = "weights"
generator_model_path = os.path.join(save_dir, "generator_model.pth")
discriminator_model_path = os.path.join(save_dir, "discriminator_model.pth")
g_optimizer_path = os.path.join(save_dir, "g_optimizer.pth")
d_optimizer_path = os.path.join(save_dir, "d_optimizer.pth")



In [26]:

torch.save({
        'epoch': 2,
        'model_state_dict': generator.state_dict(),
        'optimizer_state_dict': g_optimizer.state_dict(),
        'loss': g_loss.item()
    }, generator_model_path)

torch.save({
        'epoch': 2,
        'model_state_dict': discriminator.state_dict(),
        'optimizer_state_dict': d_optimizer.state_dict(),
        'loss': d_loss.item()
    }, discriminator_model_path)

torch.save(g_optimizer.state_dict(), g_optimizer_path)
torch.save(d_optimizer.state_dict(), d_optimizer_path)

In [17]:
# Loading model and optimizer state dicts
checkpoint = torch.load(generator_model_path)
generator.load_state_dict(checkpoint['model_state_dict'])
g_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
g_start_epoch = checkpoint['epoch'] + 1

checkpoint = torch.load(discriminator_model_path)
discriminator.load_state_dict(checkpoint['model_state_dict'])
d_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
d_start_epoch = checkpoint['epoch'] + 1


In [None]:
Test Accuracy: 30.60%