In [None]:
!pip install transformers
!pip install datasets
!pip install peft

In [2]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch
from torch import nn
import pandas as pd
from torch.utils.data import DataLoader
import re
from torch.optim import AdamW
from tqdm import tqdm, trange
from peft import LoraConfig, TaskType
from peft import get_peft_model
from datasets import load_dataset, Features, Value
from torch.nn import functional as F
import random
import numpy as np
import os
def set_random_seed(seed):
    """Seed every RNG this notebook relies on so that runs are repeatable.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch (CPU and
    all CUDA devices), and configures cuDNN for reproducible kernels.

    Args:
        seed: integer seed applied to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)           # torch.random.manual_seed is the same generator
    torch.cuda.manual_seed_all(seed)  # all visible GPUs; silently ignored without CUDA
    # Per the PyTorch reproducibility notes, cuDNN needs deterministic kernels AND
    # benchmark mode off (auto-tuning may choose different algorithms between runs).
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_random_seed(123)

In [None]:


plm = "EleutherAI/pythia-160m-deduped"

bos = '<|endoftext|>'
eos = '<|END|>'
pad = '<|pad|>'
sep = '\n\n####\n\n'

special_tokens_dict = {'eos_token': eos, 'bos_token': bos, 'pad_token': pad, 'sep_token': sep}

tokenizer = AutoTokenizer.from_pretrained(plm, revision="step3000")
num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)

In [None]:
# Label value that PyTorch's CrossEntropyLoss (ignore_index default) -- and hence the
# Hugging Face causal-LM loss -- skips; collate_batch writes it over pad positions.
IGNORED_PAD_IDX = -100
# Vocabulary id of the pad token registered above; displayed as a quick sanity check.
PAD_IDX = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
PAD_IDX

In [6]:
files="/Data/train_datas.csv"
datas = pd.read_csv(files)

In [None]:

# One gold label row: "<TYPE>\t<start>\t<end>\t<text...>" ('.' stops at a newline,
# so each line of a multi-line label string is matched separately).
_LABEL_ROW = re.compile(r'(\w+)\t(\d+)\t(\d+)\t(.+)')


def modify_string(original_string):
    """Drop the start/end offset columns from every label row in ``original_string``.

    Example: "DATE\t10\t20\t2020\t2020-01-01" -> "DATE\t2020\t2020-01-01".
    Text that does not match the row pattern is returned unchanged.
    """
    return _LABEL_ROW.sub(r'\1\t\4', original_string)

# Each training example is  bos + sentence + sep + labels-without-offsets + eos,
# so the causal LM learns to continue "sentence + sep" with its PHI labels.
train_data = [bos + contents + sep + modify_string(label_text) + eos
              for contents, label_text in zip(datas['contents'], datas['labels'])]

# Preview the first examples; slicing (instead of range(100)) cannot raise
# IndexError when the dataset has fewer than 100 rows.
for example in train_data[:100]:
    print(example)
print("============================================")
print(len(train_data))

def collate_batch(batch):
    """Tokenize a list of training strings into padded causal-LM tensors.

    Args:
        batch: list of example strings (as produced for ``train_data``).

    Returns:
        (input_ids, labels, attention_mask) as tensors. ``labels`` is a copy of
        ``input_ids`` whose pad-token positions are set to IGNORED_PAD_IDX so the
        loss ignores them.
    """
    encoding = tokenizer(batch, padding=True)
    input_ids = torch.tensor(encoding['input_ids'])

    labels = input_ids.clone()
    labels[labels == tokenizer.pad_token_id] = IGNORED_PAD_IDX

    return input_ids, labels, torch.tensor(encoding['attention_mask'])

# Sanity check: an unshuffled loader with tiny batches; peek at the first two batches.
train_dataloader = DataLoader(train_data, batch_size=2, shuffle=False, collate_fn=collate_batch)
titer = iter(train_dataloader)
tks, labels, mask = next(titer)
print(tks.shape)
# iter() of a DataLoader iterator returns the iterator itself, so this shows batch two.
next(titer)

In [None]:

# Batch size for the length-bucketed training DataLoader built below.
BATCH_SIZE = 8
# Re-seed so the sampler's shuffling starts from a known state when this cell is run.
set_random_seed(123)

class BatchSampler():
    """Yield batches of indices that group examples of similar length.

    Every iteration shuffles all indices, cuts them into pools of
    ``batch_size * pool_factor`` examples, sorts each pool by ``len(example)``
    (longest first) and yields consecutive runs of ``batch_size`` indices. Similar
    lengths inside a batch keep padding small; shuffling before pooling keeps the
    pool contents random.

    Args:
        data: re-iterable, sized collection of examples. Length is measured with
            ``len(example)`` -- characters for strings, not tokens.
        batch_size: indices per yielded batch (the last batch may be smaller).
        pool_factor: how many batches' worth of examples are sorted together.
    """

    def __init__(self, data, batch_size, pool_factor=100):
        self.data = data
        self.batch_size = batch_size
        self.pool_factor = pool_factor
        self.len = len(list(data))
        # Final index order of the most recent iteration (kept for inspection).
        self.pooled_indices = []

    def __iter__(self):
        indexed_lengths = [(index, len(example)) for index, example in enumerate(self.data)]
        # Uses the global `random` module, so the order follows the last seed set.
        random.shuffle(indexed_lengths)

        # Use self.batch_size (not a global) so iteration agrees with __len__.
        pool_size = self.batch_size * self.pool_factor
        pooled = []
        for i in range(0, len(indexed_lengths), pool_size):
            pool = indexed_lengths[i:i + pool_size]
            pooled.extend(sorted(pool, key=lambda pair: pair[1], reverse=True))
        self.pooled_indices = [index for index, _ in pooled]

        for i in range(0, len(self.pooled_indices), self.batch_size):
            yield self.pooled_indices[i:i + self.batch_size]

    def __len__(self):
        # Number of batches, counting a final partial batch.
        return (self.len + self.batch_size - 1) // self.batch_size

# Training loader: BatchSampler decides batch membership (length-bucketed), so no
# batch_size/shuffle arguments are passed here.
bucket_train_dataloader = DataLoader(
    train_data,
    batch_sampler=BatchSampler(train_data, BATCH_SIZE),
    collate_fn=collate_batch,
    pin_memory=True,
)

In [None]:
set_random_seed(123)
model = AutoModelForCausalLM.from_pretrained(plm, revision='step3000')
# Resize BEFORE building the optimizer: when the vocabulary size changes,
# resize_token_embeddings may install new embedding parameters, and an optimizer
# created earlier would keep stepping the discarded tensors instead.
model.resize_token_embeddings(len(tokenizer))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
optimizer = AdamW(model.parameters(), lr=5e-5)

In [None]:
# Resume-from-checkpoint snippet, disabled by wrapping it in a string literal (the
# string is evaluated and discarded, so nothing inside runs). If re-enabled it also
# defines `epoch`, which the disabled save block after the full fine-tuning loop uses.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
'''
checkpoint_filename = os.path.join('/models/pythia160md', "model_epoch_2_lora_1.pt")
checkpoint = torch.load(checkpoint_filename)

model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

epoch = checkpoint['epoch']
train_loss = checkpoint['train_loss']
'''

In [None]:
EPOCHS = 2
# Seed ONCE before training. Re-seeding at the top of every epoch replays the exact
# same BatchSampler shuffle (global `random`) and dropout masks each epoch.
set_random_seed(123)
for epo in trange(EPOCHS, desc="Epoch"):
    model.train()
    total_loss = 0.0

    for seqs, labels, masks in bucket_train_dataloader:
        seqs = seqs.to(device)
        labels = labels.to(device)
        masks = masks.to(device)

        model.zero_grad()
        # Pass the attention mask (as the LoRA loop does) so pad positions are masked
        # out of attention; pad labels are already IGNORED_PAD_IDX in the loss.
        outputs = model(seqs, attention_mask=masks, labels=labels)
        loss = outputs.loss.mean()

        total_loss += loss.item()
        loss.backward()
        optimizer.step()

    avg_train_loss = total_loss / len(bucket_train_dataloader)
    print("Epoch {}: Average train loss: {}".format(epo + 1, avg_train_loss))
# Per-epoch checkpointing, disabled by wrapping it in a string literal. The string
# starts at column 0, so it sits AFTER the epoch loop; it references `epo` and
# `epoch`, and `epoch` only exists if the (also disabled) resume cell was run.
# Re-enabling it means moving it back inside the loop and defining `epoch` first.
'''
    checkpoint_filename = os.path.join('/models/pythia160md', f"model_epoch_{epo + 1 + epoch}.pt")
    torch.save({
        'epoch': epo + 1 + epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_loss': avg_train_loss,
    }, checkpoint_filename)
'''

In [None]:

set_random_seed(123)
peft_config = LoraConfig(task_type=TaskType.CAUSAL_LM, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1)

# Resize before wrapping and before the optimizer is built: a resize that changes the
# vocabulary size may install new embedding parameters an earlier optimizer would miss.
# Presumably a no-op here, since the model cell already resized to len(tokenizer).
model.resize_token_embeddings(len(tokenizer))
model = get_peft_model(model, peft_config)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.print_trainable_parameters()
model.to(device)

# PEFT freezes the base weights; give AdamW only the parameters that will get gradients.
optimizer = AdamW([p for p in model.parameters() if p.requires_grad], lr=5e-5)

In [None]:
EPOCHS = 1
# Seed ONCE before training; re-seeding inside the epoch loop would replay the same
# shuffle and dropout masks every epoch as soon as EPOCHS > 1.
set_random_seed(123)
for _ in trange(EPOCHS, desc="Epoch"):
    model.train()
    total_loss = 0.0

    for seqs, labels, masks in bucket_train_dataloader:
        seqs = seqs.to(device)
        labels = labels.to(device)
        masks = masks.to(device)

        model.zero_grad()
        outputs = model(seqs, attention_mask=masks, labels=labels)
        loss = outputs.loss.mean()

        total_loss += loss.item()
        loss.backward()
        optimizer.step()

    avg_train_loss = total_loss / len(bucket_train_dataloader)
    print("Average train loss: {}".format(avg_train_loss))

In [None]:
# Fold the trained LoRA adapter weights into the base weights and drop the PEFT
# wrapper, so `model` is the underlying causal LM again for saving and inference.
model = model.merge_and_unload()

In [None]:
checkpoint_dir = '/models/pythia160md'
# torch.save does not create missing directories.
os.makedirs(checkpoint_dir, exist_ok=True)
checkpoint_filename = os.path.join(checkpoint_dir, "model_epoch_2_lora_1.pt")
torch.save({
    # 2 full fine-tuning epochs + 1 LoRA epoch, matching the file name.
    'epoch': 3,
    'model_state_dict': model.state_dict(),
    # NOTE(review): this optimizer was built over the PEFT-wrapped model, so after
    # merge_and_unload its state does not line up with the merged model's parameters;
    # loading it into a fresh AdamW (as the disabled resume cell does) will likely fail.
    'optimizer_state_dict': optimizer.state_dict(),
    'train_loss': avg_train_loss,
}, checkpoint_filename)

In [None]:

valid_data = load_dataset("csv", data_files="/Data/opendid_valid.tsv", delimiter='\t',
                          features = Features({
                              'fid': Value('string'),
                              'idx': Value('int64'),
                              'content': Value('string')}),
                              column_names=['fid', 'idx', 'content'])
valid_list = list(valid_data['train'])

In [None]:


# PHI types the model may emit (spelling, e.g. 'VECHICLE', kept as-is; presumably it
# matches the dataset's label names).
dic = [ 'PATIENT', 'DOCTOR', 'USERNAME', 'PROFESSION', 'ROOM', 'DEPARTMENT', 'HOSPITAL', 'ORGANIZATION', 'STREET', 'CITY', 'STATE', 'COUNTRY', 'ZIP', 'LOCATION-OTHER', 'AGE', 'DATE',
    'TIME', 'DURATION', 'SET', 'PHONE', 'FAX', 'EMAIL', 'URL', 'IPADDR', 'SSN', 'MEDICALRECORD', 'HEALTHPLAN', 'ACCOUNT', 'LICENSE', 'VECHICLE', 'DEVICE', 'BIOID', 'IDNUM', 'OTHER']

# Accepted normalized DURATION values: P<number>[.<number>] followed by W, K, Y, D or M.
du = r'P\d\d?(\.\d\d?)?W|P\d\d?(\.\d\d?)?K|P\d\d?(\.\d\d?)?Y|P\d\d?(\.\d\d?)?D|P\d\d?(\.\d\d?)?M'

# A generated label line is "<PHI type>\t<text>[\t<normalized value>]". Type names may
# contain '-' (LOCATION-OTHER), so they must be matched with [\w-]+: the previous \w+
# split "LOCATION-OTHER", glued "LOCATION-" onto the text, and every such label was lost.
_LABEL_PATTERN = re.compile(r'([\w-]+)\t(.+)')

# PHI types whose label carries a normalized value, mapped to how many characters of
# that value are kept.
_NORMALIZED_MAX_LEN = {'DATE': 10, 'TIME': 16, 'DURATION': 4}


def _locate_phi(label, file_id, offset, text, search_from):
    """Turn one generated label line into an answer row.

    Args:
        label: generated line, e.g. "DOCTOR\tSmith" or "DATE\t2020/01/02\t2020-01-02".
        file_id: document id written as the first column.
        offset: int or numeric string added to in-sentence positions.
        text: the sentence the labels refer to.
        search_from: index in ``text`` where the search starts; labels are expected
            in order of appearance, so each accepted span moves this cursor forward.

    Returns:
        (row, next_search_from). ``row`` is "file\tPHI\tstart\tend\tvalue", or "" when
        the line is malformed, has an unknown type, an invalid normalized value, or
        its text is not found. The cursor only advances when a row is produced.
    """
    match = _LABEL_PATTERN.match(label)
    if match is None:
        return "", search_from
    phi, content = match.group(1), match.group(2)
    if phi not in dic:
        return "", search_from

    if phi in _NORMALIZED_MAX_LEN:
        fields = content.split('\t')
        if len(fields) < 2:
            return "", search_from
        surface = fields[0]
        normalized = fields[1][:_NORMALIZED_MAX_LEN[phi]]
        if phi == 'DURATION':
            duration = re.match(du, normalized)
            if duration is None:
                return "", search_from
            normalized = duration.group()
        content = surface + '\t' + normalized
    else:
        surface = content

    position = text.find(surface, search_from)
    if position == -1:
        return "", search_from
    begin = int(offset) + position
    row = f'{file_id}\t{phi}\t{begin}\t{begin + len(surface)}\t{content}'
    return row, position + len(surface)


def predict(model, tokenizer, input, max_new_tokens=400):
    """Greedily decode PHI labels for one sentence and map them to character spans.

    Args:
        model: causal LM whose outputs expose ``.logits`` and ``.past_key_values``.
        tokenizer: the training tokenizer (callable returning ``input_ids``, ``decode``).
        input: "<fid><-#-#-#-><offset><-#-#-#-><sentence>".
        max_new_tokens: maximum number of generated tokens.

    Returns:
        Newline-separated rows "fid\tPHI\tstart\tend\tvalue", or "" when the model
        predicts "PHI: NULL" or none of its labels can be located in the sentence.
    """
    # maxsplit=2 keeps the sentence intact even if it contains the delimiter.
    parts = input.split('<-#-#-#->', 2)
    file_id, offset, text = parts[0], parts[1], parts[2]

    model.eval()
    token_ids = tokenizer(bos + text + sep)['input_ids']
    step_input, past_key_values = torch.tensor([token_ids]), None

    with torch.no_grad():
        for _ in range(max_new_tokens):
            out = model(step_input.to(device), past_key_values=past_key_values)
            past_key_values = out.past_key_values
            # Greedy decoding: the argmax of the logits equals the argmax of the softmax.
            next_id = int(torch.argmax(out.logits[:, -1], dim=-1))
            if tokenizer.decode(next_id) == eos:
                break
            token_ids.append(next_id)
            # With the KV cache, only the newest token has to be fed back in.
            step_input = torch.tensor([[next_id]])

    decoded = tokenizer.decode(token_ids)
    pred = decoded[decoded.find(sep) + len(sep):].replace(pad, "").replace(eos, "").strip()
    if pred.split('\t')[-1] == 'PHI: NULL':
        return ""

    rows, cursor = [], 0
    for label in pred.split('\n'):
        row, cursor = _locate_phi(label, file_id, offset, text, cursor)
        if row:
            rows.append(row)
    return '\n'.join(rows)


In [None]:
answer = []

# Rule-based patterns that complement the model's predictions.
lo = r'P\.?\ ?O\.?\ ?BOX \d+'                          # P.O. boxes -> LOCATION-OTHER
co = ['South Africa', 'Vietnam', 'USA', 'Australia']    # literal names -> COUNTRY


def _phi_line(fid, phi, start, text):
  """Format one answer row: fid, PHI type, start, end (start + len(text)), text."""
  return fid + '\t' + phi + '\t' + str(start) + '\t' + str(start + len(text)) + '\t' + text


for sent in valid_list:

  input_ = sent['content']
  if input_ is None: continue
  ans = predict(model, tokenizer, sent['fid']+'<-#-#-#->'+str(sent['idx'])+'<-#-#-#->'+input_)

  # COUNTRY heuristic (skipped entirely once the prediction contains a ZIP).
  for c in co:
    if ans.find('ZIP') > -1: break
    pos1 = input_.find(c)
    if pos1 > -1:
      pos1 += sent['idx']
      country_line = _phi_line(sent['fid'], 'COUNTRY', pos1, c)
      if ans is None or ans == "":
        ans = country_line
        break
      # Insert into the first gap between predicted spans that can hold the country.
      a1 = sent['idx']
      a2 = ans.split('\n')
      for i in range(len(a2)):
        if pos1+len(c) <= int(a2[i].split('\t')[2]) and pos1 >= a1:
          a2.insert(i, country_line)
          ans = '\n'.join(a2)
          break
        a1 = int(a2[i].split('\t')[3])
      # ...or append it when it follows the last predicted span.
      a1 = int(a2[-1].split('\t')[3])
      if pos1 >= a1:
        a2.append(country_line)
        ans = '\n'.join(a2)
        break

      # Relabel a prediction whose text is exactly the country name.
      for j in range(len(a2)):
        if c == a2[j].split('\t')[-1]:
          a2[j] = country_line
          ans = '\n'.join(a2)
          break

  # LOCATION-OTHER heuristic for P.O. boxes.
  match = re.search(lo, input_)
  if match:
    loc = match.group()
    pos = input_.find(loc)+sent['idx']
    loc_line = _phi_line(sent['fid'], 'LOCATION-OTHER', pos, loc)
    if ans is None or ans == "":
      ans = loc_line
      answer.append(str(ans)+'\n')
      print(ans)
      continue
    a1 = sent['idx']
    a2 = ans.split('\n')
    for i in range(len(a2)):
      if pos+len(loc) <= int(a2[i].split('\t')[2]) and pos >= a1:
        a2.insert(i, loc_line)
        break
      a1 = int(a2[i].split('\t')[3])
    else:
      # No gap before an existing span: append when the box follows the last span
      # (previously it was silently dropped, unlike in the COUNTRY branch).
      if pos >= a1:
        a2.append(loc_line)
    ans = '\n'.join(a2)

  if ans is None or ans == "": continue
  answer.append(str(ans)+'\n')
  print(ans)

with open('/Data/answer.txt', 'w', encoding='utf-8') as file:
    file.writelines(answer)