In [8]:
from datasets import load_dataset
from transformers import BartTokenizer, BartForConditionalGeneration, Trainer, TrainingArguments
from sklearn.metrics import accuracy_score
from transformers import BartTokenizer, BartForConditionalGeneration

model_name = "facebook/bart-base"  # or "facebook/bart-large"

tokenizer = BartTokenizer.from_pretrained(model_name)
model = BartForConditionalGeneration.from_pretrained(model_name)

# dataset = load_dataset("csv", data_files="/Users/eddie/Downloads/VLMTrain/MViTv2_Train.csv")
dataset = load_dataset("csv", data_files="/Users/eddie/Downloads/VLMTrain/MViTv2_Train_Recognition_Full.csv")
test_dataset = load_dataset("csv", data_files="/Users/eddie/Downloads/VLMTrain/MViTv2_Test_Recognition_Full.csv")


In [9]:
def preprocess_data(examples):
    q = "What is the next action?"
    inputs = [f"Question: {q} Context: {c}" for c in zip(examples["context"])]
    targets = examples["answer"]
    model_inputs = tokenizer(inputs, padding="max_length", truncation=True, max_length=512)
    labels = tokenizer(targets, padding="max_length", truncation=True, max_length=25)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

dataset = dataset.map(preprocess_data, batched=True)
test_dataset = test_dataset.map(preprocess_data, batched = True)
print(dataset)


DatasetDict({
    train: Dataset({
        features: ['question', 'context', 'answer', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 7429
    })
})


In [None]:
num_epochs = 1 
from transformers import TrainingArguments, Trainer
for i in range(num_epochs):

    training_args = TrainingArguments(
        output_dir="./bart_qa",
        evaluation_strategy="epoch",
        save_strategy="no",
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        num_train_epochs=30,
        weight_decay=0.01,
        logging_dir="./logs",
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset = test_dataset["train"],
        # predict_with_generate = True,
        # preprocess_logits_for_metrics=preprocess_logits_for_metrics
        # compute_loss_func=compute_loss_func
    )

    trainer.train()


wandb: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
wandb: Currently logged in as: eddie-r-zhang (eddie-r-zhang-n-a) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin


Epoch,Training Loss,Validation Loss
1,0.5953,0.086741
2,0.0409,0.082448
3,0.0295,0.084194


In [11]:
model.save_pretrained("./bart_full", from_pt=True) 
tokenizer.save_pretrained("./bart_full")



('./bart_full\\tokenizer_config.json',
 './bart_full\\special_tokens_map.json',
 './bart_full\\vocab.json',
 './bart_full\\merges.txt',
 './bart_full\\added_tokens.json')

In [16]:
# from transformers import EncoderDecoderModel, BertTokenizer

# # Load trained model and tokenizer
model_name = "./bart_full"  # Change to your actual model path
tokenizer = BartTokenizer.from_pretrained(model_name)
model = BartForConditionalGeneration.from_pretrained(model_name).to("cuda")

# Ensure correct config

def generate_answer(question, context):
    # Format input as: "question: ... context: ..."
    input_text = f"question: {question}  context: {context}"

    # Tokenize input
    inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=512).to("cuda")

    # Generate response
    output_ids = model.generate(inputs.input_ids, max_length=100, num_beams=5)

    # Decode output tokens
    answer = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return answer

# print(generate_answer("What is the next action?", "palpate landmark, take swab, prep site"))

In [17]:
test_dataset = load_dataset("csv", data_files="/Users/eddie/Downloads/VLMTrain/MViTv2_Test_Recognition_Full.csv")
test_dataset = test_dataset["train"]
# questions = test_dataset["question"]
context = test_dataset["context"]
labels = test_dataset["answer"]    
print(test_dataset)
print(labels)
print(context)

Dataset({
    features: ['question', 'context', 'answer'],
    num_rows: 683
})
['palpate landmark', 'take swab', 'prep site', 'drop swab', 'take syringe', 'inject lidocaine', 'drop syringe', 'take scalpel', 'incise skin', 'insert finger', 'incise membrane', 'drop scalpel', 'take hook', 'remove finger', 'insert hook', 'take tube', 'insert tube', 'remove hook', 'drop hook', 'inflate cuff', 'detach syringe', 'drop syringe', 'palpate landmark', 'take swab', 'prep site', 'drop swab', 'take syringe', 'palpate landmark', 'inject lidocaine', 'drop syringe', 'take scalpel', 'palpate landmark', 'incise skin', 'incise skin', 'insert finger', 'remove finger', 'incise membrane', 'insert finger', 'drop scalpel', 'take hook', 'remove finger', 'insert hook', 'take tube', 'insert tube', 'remove hook', 'drop hook', 'inflate cuff', 'detach syringe', 'drop syringe', 'palpate landmark', 'take swab', 'prep site', 'drop swab', 'take syringe', 'inject lidocaine', 'drop syringe', 'take scalpel', 'incise skin'

In [18]:
generated_answers = []
for i in range(len(context)):
    # print(questions[i], context[i])
    generated_answers.append(generate_answer("What is the next action?", context[i]))
    print(generated_answers[i], labels[i])
print(generated_answers)
print(labels)

palpate landmark palpate landmark
take swab take swab
prep site prep site
drop swab drop swab
take syringe take syringe
inject lidocaine inject lidocaine
drop syringe drop syringe
take scalpel take scalpel
incise skin incise skin
insert finger insert finger
incise membrane incise membrane
drop scalpel drop scalpel
take hook take hook
insert hook remove finger
insert hook insert hook
take tube take tube
insert tube insert tube
remove hook remove hook
drop hook drop hook
detach syringe inflate cuff
drop syringe detach syringe
drop syringe drop syringe
drop swab palpate landmark
take swab take swab
prep site prep site
drop swab drop swab
take syringe take syringe
take syringe palpate landmark
inject lidocaine inject lidocaine
drop syringe drop syringe
take scalpel take scalpel
take scalpel palpate landmark
incise skin incise skin
incise skin incise skin
insert finger insert finger
insert finger remove finger
incise skin incise membrane
incise membrane insert finger
drop scalpel drop scalp

In [19]:
correct = 0
total = 0
for i in range(len(generated_answers)):
    if((generated_answers[i])==(labels[i])):
        correct = correct+1
        print(correct, total)
    total = total+1
print(correct/total)

1 0
2 1
3 2
4 3
5 4
6 5
7 6
8 7
9 8
10 9
11 10
12 11
13 12
14 14
15 15
16 16
17 17
18 18
19 21
20 23
21 24
22 25
23 26
24 28
25 29
26 30
27 32
28 33
29 34
30 38
31 39
32 41
33 42
34 43
35 44
36 45
37 46
38 47
39 48
40 49
41 50
42 51
43 52
44 53
45 54
46 55
47 56
48 57
49 58
50 61
51 63
52 64
53 66
54 67
55 69
56 70
57 71
58 72
59 73
60 74
61 75
62 76
63 77
64 78
65 79
66 80
67 81
68 82
69 83
70 84
71 85
72 86
73 87
74 88
75 89
76 90
77 91
78 92
79 93
80 94
81 95
82 96
83 97
84 98
85 99
86 100
87 101
88 102
89 103
90 104
91 105
92 106
93 107
94 108
95 109
96 110
97 111
98 112
99 113
100 114
101 116
102 117
103 118
104 119
105 120
106 121
107 122
108 123
109 124
110 125
111 128
112 129
113 130
114 132
115 133
116 134
117 135
118 136
119 138
120 139
121 140
122 142
123 144
124 147
125 155
126 157
127 160
128 163
129 169
130 175
131 177
132 178
133 179
134 181
135 183
136 184
137 185
138 186
139 187
140 189
141 191
142 193
143 194
144 195
145 196
146 198
147 203
148 205
149 211
150 212
151