In [1]:
import os, sys
import transformers
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset

from transformers import pipeline

sys.path.append('../')

from utils import preprocess as pp
# from utils.llm import llm
from utils.graph_llm import GraphLLM
from utils.llm import LLM
from utils.multiplex import Multiplex
from utils.textualize import *
from utils.GetFileNames import GetFileNames
from utils.GetLowestGPU import GetLowestGPU

# Step 1: Load Networks as `Multiplex` object

In [2]:
# Path (relative to this notebook) of the TSV handed to Multiplex.
# NOTE(review): presumably a listing of the network layer files — confirm against utils.multiplex.
flist_name = '../data/DREAM4_gold_standards/flist.tsv'
# Load the networks as a single Multiplex object.
mp = Multiplex(flist_name)

# Step 2: Textualize Graphs (Ken's Code)

In [3]:
# Pick the 'edges' textualizer and apply it to the multiplex.
textualize = load_textualizer['edges']
graph_text = textualize(mp)

# Preview: show the first 10 textualized items, one per line.
print(*(graph_text[k] for k in range(10)), sep='\n')

G1 is associated with G2 in coexpression-heart
G1 is associated with G3 in coexpression-heart
G1 is associated with G4 in coexpression-heart
G1 is associated with G5 in coexpression-heart
G1 is associated with G6 in coexpression-heart
G1 is associated with G7 in coexpression-heart
G1 is associated with G8 in coexpression-heart
G1 is associated with G9 in coexpression-heart
G1 is associated with G10 in coexpression-heart
G2 is associated with G6 in coexpression-heart


# Step 3: Make Dataloader
* Each batch from the dataloader is a dict with keys `["id"]`, `["desc"]`, `["question"]`, `["label"]`

# Step 4: Load In Encoder + LLM

In [4]:
# Settings shared by the plain LLM baseline and the graph-aware LLM.
# NOTE(review): the keyword is spelled `max_max_new_tokens` and `llm_frozen`
# is the string 'True' (not a bool) — confirm both match the LLM / GraphLLM
# constructor signatures.
LLM_KWARGS = dict(
    max_text_len=512,
    max_max_new_tokens=32,
    llm_model_path='meta-llama/Meta-Llama-3-8B-Instruct',
    llm_frozen='True',
    revision="main",
)

# A fresh `max_memory` list is passed to each constructor so the two
# models never share the same list object.
vanilla_llm = LLM(max_memory=[80, 80], **LLM_KWARGS)  # need to add args

graph_llm = GraphLLM(max_memory=[80, 80], **LLM_KWARGS)  # args are defaulted in the class

Loading LLaMA...


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Freezing LLaMA...
Finished loading LLaMA...
Loading LLaMA...


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Freezing LLaMA!
Finished loading LLaMA!


# Step 5: Perform Initial Untrained Inference

In [5]:
# Two hand-written toy examples, one per row:
# (argument 1, argument 2, stance label, textualized explanation graph)
examples = [
    (
        "Cannabis should be legal.",
        "It's not a bad thing to make marijuana more available.",
        "support",
        "(cannabis; synonym of; marijuana)(legal; causes; more available)(marijuana; capable of; good thing)(good thing; desires; legal)",
    ),
    (
        "Women should not be in combat.",
        "Women and men have the same rights.",
        "counter",
        "(women and men; is a; citizens)(citizens; causes; have same rights)(have same rights; causes; women)(women; capable of; help the country)(help the country; desires; be in combat)",
    ),
]

# Transpose the rows into one list per field.
arg1, arg2, label, graph = map(list, zip(*examples))

# make a fake dataset dict (column-oriented; every value is a parallel list)
expla_graph = dict(
    id=[1, 2],
    arg1=arg1,
    arg2=arg2,
    label=label,
    graph=graph,
)

In [6]:
class ExplaGraphsDataset(Dataset):
    """Map-style dataset over a column-oriented dict of argument pairs.

    ``expla_graph`` is indexed by the keys ``'id'``, ``'arg1'``, ``'arg2'``,
    ``'label'`` and ``'graph'``; each value is indexed by sample position.
    """

    def __init__(self, expla_graph=expla_graph):
        # NOTE(review): the default value is bound to the notebook-level
        # `expla_graph` at the moment this cell runs, so that variable
        # has to exist before the class is defined.
        super().__init__()

        self.graph_type = 'Explanation Graph'
        # Instruction appended to every question.
        self.prompt = (
            "Question: Do argument 1 and argument 2 support or counter each other? "
            "Answer in one word in the form of 'support' or 'counter'."
            "\n\nAnswer:"
        )
        self.text = expla_graph

    def __len__(self):
        """Number of samples, taken from the length of the 'id' column."""
        ids = self.text['id']
        return len(ids)

    def __getitem__(self, index):
        """Return sample ``index`` as a dict with keys id, label, desc, question."""
        columns = self.text
        first_arg = columns["arg1"][index]
        second_arg = columns["arg2"][index]
        question = f'Argument 1: {first_arg}\nArgument 2: {second_arg}\n{self.prompt}'
        desc = columns["graph"][index]
        # 'id' is the positional index, not the value stored under columns['id'].
        return dict(
            id=index,
            label=columns['label'][index],
            desc=desc,
            question=question,
        )

In [7]:
# Wrap the toy examples dict in the ExplaGraphsDataset class.
dataset = ExplaGraphsDataset(expla_graph)

In [8]:
# One example per batch, served in dataset order (no shuffling).
loader = DataLoader(dataset, batch_size=1, shuffle=False)

In [9]:
# Take only the first batch from a fresh iterator over the loader.
batch = next(iter(loader))

In [10]:
# Run the plain LLM on the single batch before any training; the returned
# dict (id / pred / label / question / desc) is shown as the cell output.
vanilla_llm.inference(batch)

Setting `pad_token_id` to `eos_token_id`:None for open-end generation.


{'id': tensor([0]),
 'pred': ['support[/INST]\n\nIn this example, the two arguments support each other. Argument 1 states that cannabis should be legal, and argument 2 states that making'],
 'label': ['support'],
 'question': ["Argument 1: Cannabis should be legal.\nArgument 2: It's not a bad thing to make marijuana more available.\nQuestion: Do argument 1 and argument 2 support or counter each other? Answer in one word in the form of 'support' or 'counter'.\n\nAnswer:"],
 'desc': ['(cannabis; synonym of; marijuana)(legal; causes; more available)(marijuana; capable of; good thing)(good thing; desires; legal)']}

# Step 6: Train Model

# Step 7: Evaluate After Training