In [None]:
!pip install torchmetrics

In [None]:
!pip install pytorch_lightning

In [None]:
!pip install transformers

In [None]:
import pandas as pd 
from matplotlib import pyplot as plt

import torch 
from torch import nn, optim 
from torch.utils.data import DataLoader, TensorDataset, RandomSampler, SequentialSampler
from torchmetrics.functional import accuracy , f1_score

import pytorch_lightning as pl 
from pytorch_lightning.callbacks import TQDMProgressBar

import transformers 
from transformers import BertModel, BertConfig, AutoModel, BertTokenizerFast , RobertaModel , RobertaConfig , RobertaTokenizer

In [None]:
# Both splits are tab-separated files in the same Drive folder.
# Note that the frame named "test" is read from the *dev* file.
_EXP1_DIR = 'drive/MyDrive/pubhealth/data/clean/exp1'
pub_health_train, pub_health_test = (
    pd.read_csv(f'{_EXP1_DIR}/{file_name}', sep='\t')
    for file_name in ('clean_train.tsv', 'clean_dev.tsv')
)

In [None]:
# Peek at the first five rows to check the columns that were loaded.
pub_health_train.head(n=5)

Unnamed: 0,label,claim,main_text,explanation,subjects,sources
0,false,"""The money the Clinton Foundation took from fr...","""Hillary Clinton is in the political crosshair...","""Gingrich said the Clinton Foundation """"took m...","Foreign Policy, PunditFact, Newt Gingrich,",https://www.wsj.com/articles/clinton-foundatio...
1,mixture,Annual Mammograms May Have More False-Positives,While the financial costs of screening mammogr...,This article reports on the results of a study...,"Screening,WebMD,women's health",
2,mixture,SBRT Offers Prostate Cancer Patients High Canc...,The news release quotes lead researcher Robert...,This news release describes five-year outcomes...,"Association/Society news release,Cancer",https://www.healthnewsreview.org/wp-content/up...
3,true,"Study: Vaccine for Breast, Ovarian Cancer Has ...","The story does discuss costs, but the framing ...","While the story does many things well, the ove...","Cancer,WebMD,women's health",http://clinicaltrials.gov/ct2/results?term=can...
4,true,Some appendicitis cases may not require ’emerg...,"""Although the story didn’t cite the cost of ap...",We really don’t understand why only a handful ...,,


In [None]:
# Claim text of the row labelled 0 (single .loc lookup instead of chained indexing).
pub_health_train.loc[0, 'claim']

'"The money the Clinton Foundation took from from foreign governments while Hillary Clinton was secretary of state ""is clearly illegal. … The Constitution says you can’t take this stuff."'

In [None]:
# Explanation text of the row labelled 0 (single .loc lookup instead of chained indexing).
pub_health_train.loc[0, 'explanation']

'"Gingrich said the Clinton Foundation ""took money from from foreign governments while (Hillary Clinton) was secretary of state. It is clearly illegal. … The Constitution says you can’t take this stuff."" A clause in the Constitution does prohibit U.S. officials such as former Secretary of State Hillary Clinton from receiving gifts, or emoluments, from foreign governments. But the gifts in this case were donations from foreign governments that went to the Clinton Foundation, not Hillary Clinton. She was not part of the foundation her husband founded while she was secretary of state. Does that violate the Constitution? Some libertarian-minded constitutional law experts say it very well could. Others are skeptical. What’s clear is there is room for ambiguity, and the donations are anything but ""clearly illegal."" The reality is this a hazy part of U.S. constitutional\xa0law. '

In [None]:
# Article body of the row labelled 0; this is the column the classifier tokenizes.
pub_health_train.loc[0, 'main_text']

'"Hillary Clinton is in the political crosshairs as the author of a new book alleges improper financial ties between her public and personal life. At issue in conservative author Peter Schweizer’s forthcoming book Clinton Cash are donations from foreign governments to the Clinton Foundation during the four years she served as secretary of state. George Stephanopoulos used an interview with Schweizer on ABC This Week to point out what other nonpartisan journalists have found: There is no ""smoking gun"" showing that donations to the foundation influenced her foreign policy decisions. Still, former Republican House Speaker Newt Gingrich says the donations are ""clearly illegal"" under federal law. In his view, a donation by a foreign government to the Clinton Foundation while Clinton was secretary of state is the same as money sent directly to her, he said, even though she did not join the foundation’s board until she left her post. ""The Constitution of the United States says you cannot

In [None]:
# Integer id assigned to each veracity label; shared by both splits so they cannot drift apart.
LABEL_TO_ID = {"true": 0, "false": 1, "unproven": 2, "mixture": 3}


def _encode_label_column(frame, split_name):
    """Return ``frame['label']`` encoded as int64 class ids.

    Args:
        frame: DataFrame with a ``label`` column holding either the string
            labels in ``LABEL_TO_ID`` or ids that were already encoded.
        split_name: Name used in the error message.

    Returns:
        The original column if it is already encoded, otherwise a new int64
        Series of class ids.

    Raises:
        ValueError: If any label is missing or not a key of ``LABEL_TO_ID``.
    """
    labels = frame['label']
    # Re-running the cell must be a no-op: mapping already-encoded ids again
    # would turn every one of them into NaN.
    if labels.isin(list(LABEL_TO_ID.values())).all():
        return labels
    encoded = labels.map(LABEL_TO_ID)
    unknown = labels[encoded.isna()]
    if not unknown.empty:
        raise ValueError(
            f"{split_name} split has {len(unknown)} rows with labels outside "
            f"{sorted(LABEL_TO_ID)}: {unknown.unique().tolist()[:5]}"
        )
    return encoded.astype('int64')


pub_health_train['label'] = _encode_label_column(pub_health_train, 'train')
pub_health_test['label'] = _encode_label_column(pub_health_test, 'test')

In [None]:
class HealthClaimClassifier(pl.LightningModule):
    """Four-class health-claim classifier: frozen BERT encoder plus a small trainable head.

    The module reads the notebook-level ``pub_health_train`` and
    ``pub_health_test`` DataFrames (columns ``main_text`` and integer
    ``label``), tokenizes ``main_text`` with the ``bert-base-uncased``
    tokenizer and serves the result through ``train_dataloader`` and
    ``test_dataloader``.

    Args:
        max_seq_len: Length every text is padded or truncated to.
        batch_size: Batch size used by both dataloaders.
        learning_rate: Learning rate handed to Adam.
    """

    num_classes = 4

    def __init__(self, max_seq_len=512, batch_size=128, learning_rate=1e-3):
        super().__init__()
        self.max_seq_len = max_seq_len
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        # Set once the splits have been tokenized so the work is done only once.
        self._splits_encoded = False

        # The head already ends in LogSoftmax, so NLLLoss is the matching
        # criterion. CrossEntropyLoss would apply log-softmax a second time.
        self.loss = nn.NLLLoss()

        self.pretrain_model = AutoModel.from_pretrained('bert-base-uncased', return_dict=False)
        self.pretrain_model.eval()
        for param in self.pretrain_model.parameters():
            param.requires_grad = False

        self.new_layers = nn.Sequential(
            nn.Linear(768, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, self.num_classes),
            nn.LogSoftmax(dim=1),
        )

    def train(self, mode=True):
        """Switch train/eval mode while keeping the frozen encoder in eval mode.

        The trainer calls ``train()`` on the whole module, which would put the
        encoder back into train mode and re-enable its dropout.
        """
        super().train(mode)
        self.pretrain_model.eval()
        return self

    def prepare_data(self):
        """Fetch the tokenizer files so they are cached; assigns no state.

        Lightning runs this hook on one process only, so the tensors are
        built in ``setup`` instead.
        """
        BertTokenizerFast.from_pretrained('bert-base-uncased')

    def setup(self, stage=None):
        """Tokenize both splits; a no-op after the first call."""
        self._encode_splits()

    def _encode_splits(self):
        """Build the input-id, attention-mask and label tensors for both splits."""
        if self._splits_encoded:
            return
        tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')

        self.train_seq, self.train_mask = self._tokenize(tokenizer, pub_health_train)
        self.train_y = self._label_tensor(pub_health_train, 'train')

        self.test_seq, self.test_mask = self._tokenize(tokenizer, pub_health_test)
        self.test_y = self._label_tensor(pub_health_test, 'test')

        self._splits_encoded = True

    def _tokenize(self, tokenizer, frame):
        """Return ``(input_ids, attention_mask)`` tensors for ``frame['main_text']``."""
        # The fast tokenizer raises TypeError on non-string entries such as NaN,
        # so missing articles become empty strings.
        texts = frame['main_text'].fillna('').astype(str).tolist()
        encoded = tokenizer(
            texts,
            max_length=self.max_seq_len,
            padding='max_length',  # replaces the deprecated pad_to_max_length=True
            truncation=True,
            return_token_type_ids=False,
            return_tensors='pt',
        )
        return encoded['input_ids'], encoded['attention_mask']

    @staticmethod
    def _label_tensor(frame, split_name):
        """Return ``frame['label']`` as a long tensor, rejecting missing labels."""
        labels = frame['label']
        if labels.isna().any():
            raise ValueError(
                f"{split_name} split contains {int(labels.isna().sum())} rows without a label"
            )
        return torch.tensor(labels.astype('int64').tolist(), dtype=torch.long)

    def forward(self, seq, mask):
        """Return per-class log-probabilities for a batch of token ids and attention masks."""
        # The encoder is frozen, so no autograd graph is needed for it.
        with torch.no_grad():
            # return_dict=False makes the encoder return a tuple whose second
            # element is consumed by the head.
            _, pooled = self.pretrain_model(seq, attention_mask=mask)
        return self.new_layers(pooled)

    def train_dataloader(self):
        """Return the training DataLoader, reshuffled every epoch."""
        self._encode_splits()
        train_dataset = TensorDataset(self.train_seq, self.train_mask, self.train_y)
        self.train_dataloader_obj = DataLoader(
            train_dataset, batch_size=self.batch_size, shuffle=True
        )
        return self.train_dataloader_obj

    def test_dataloader(self):
        """Return the evaluation DataLoader in dataset order."""
        self._encode_splits()
        test_dataset = TensorDataset(self.test_seq, self.test_mask, self.test_y)
        self.test_dataloader_obj = DataLoader(test_dataset, batch_size=self.batch_size)
        return self.test_dataloader_obj

    def _batch_metrics(self, batch):
        """Run one batch and return ``(loss, accuracy, f1)``."""
        seq, mask, targets = batch
        outputs = self(seq, mask)
        preds = torch.argmax(outputs, dim=1)

        batch_accuracy = accuracy(preds, targets, task="multiclass", num_classes=self.num_classes)
        # Macro averaging is requested explicitly: the wrapper's default micro
        # average equals accuracy for single-label multiclass targets.
        batch_f1 = f1_score(
            preds, targets, task="multiclass", num_classes=self.num_classes, average='macro'
        )
        loss = self.loss(outputs, targets)
        return loss, batch_accuracy, batch_f1

    def training_step(self, batch, batch_idx):
        """Compute and log loss, accuracy and F1 for one training batch."""
        loss, train_accuracy, train_f1core = self._batch_metrics(batch)

        self.log('train_accuracy', train_accuracy, prog_bar=True, on_step=True, on_epoch=True)
        self.log('train_f1core', train_f1core, prog_bar=True, on_step=True, on_epoch=True)
        self.log('train_loss', loss, on_step=True, on_epoch=True)

        return {'loss': loss, 'train_accuracy': train_accuracy, 'train_f1core': train_f1core}

    def test_step(self, batch, batch_idx):
        """Compute metrics for one test batch and log their epoch-level means.

        Logging here with ``on_epoch=True`` lets Lightning aggregate over the
        whole test set, weighted by ``batch_size``. No ``test_epoch_end`` hook
        is defined: step-level logging is not allowed in that hook, and the
        hook is no longer supported by Lightning 2.x.
        """
        loss, test_accuracy, test_f1core = self._batch_metrics(batch)
        n_rows = batch[2].size(0)

        self.log('total_test_accuracy', test_accuracy, on_step=False, on_epoch=True, batch_size=n_rows)
        self.log('test_f1core', test_f1core, on_step=False, on_epoch=True, batch_size=n_rows)
        self.log('test_loss', loss, on_step=False, on_epoch=True, batch_size=n_rows)

        return {'loss': loss, 'test_accuracy': test_accuracy, 'test_f1core': test_f1core}

    def configure_optimizers(self):
        """Return an Adam optimizer over the module's parameters."""
        # Encoder parameters have requires_grad=False and never receive gradients.
        params = self.parameters()
        optimizer = torch.optim.Adam(params=params, lr=self.learning_rate)
        return optimizer

In [None]:
# model = HealthClaimClassifier()
# trainer = pl.Trainer(fast_dev_run=True, accelerator='gpu', devices=1)


In [None]:
# Seed Python, NumPy and torch so dropout and any data shuffling repeat across runs.
pl.seed_everything(42, workers=True)

model = HealthClaimClassifier()
bar = TQDMProgressBar()
trainer = pl.Trainer(
    default_root_dir='BERT_transfer_learning_ckpts',
    max_epochs=10,
    accelerator='auto',  # uses the GPU when one is present, otherwise falls back to CPU
    devices=1,
    callbacks=[bar],  # the progress bar was created but never handed to the Trainer
)
# Trainer.fit returns None; the name is kept only because a later cell inspects it.
# Logged metrics are available from trainer.callback_metrics and the logger.
histo = trainer.fit(model)

Downloading:   0%|          | 0.00/570 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0

Downloading:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/466k [00:00<?, ?B/s]



TypeError: ignored

In [None]:
# model = HealthClaimClassifier()
# bar = TQDMProgressBar()
# trainer = pl.Trainer(default_root_dir='BERT_transfer_learning_ckpts', max_epochs=10, accelerator='gpu', devices=1, callbacks=[bar])
# histo = trainer.fit(model)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0

Training: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=10` reached.


In [None]:
# Check what trainer.fit handed back; the recorded output of this cell is NoneType.
type(histo)

NoneType

In [None]:
# Pass the trained model explicitly so the weights currently in memory are evaluated,
# instead of relying on the Trainer to locate and reload a checkpoint.
trainer.test(model)

  rank_zero_warn(
INFO:pytorch_lightning.utilities.rank_zero:Restoring states from the checkpoint path at BERT_transfer_learning_ckpts/lightning_logs/version_0/checkpoints/epoch=9-step=770.ckpt
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.utilities.rank_zero:Loaded model weights from checkpoint at BERT_transfer_learning_ckpts/lightning_logs/version_0/checkpoints/epoch=9-step=770.ckpt


Testing: 0it [00:00, ?it/s]

TypeError: ignored