In [None]:
# Install the notebook's dependencies with shell-escaped pip commands.
# NOTE(review): peft is installed twice (once on its own, once in the --no-deps line)
# and, apart from the xformers upper bound, no versions are pinned, so a re-run may
# resolve to different package versions than the ones shown in the recorded output.
!pip install datasets
!pip install peft
!pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
!pip install --no-deps "xformers<0.0.26" trl peft accelerate bitsandbytes

Collecting datasets
  Downloading datasets-2.19.1-py3-none-any.whl (542 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m542.0/542.0 kB[0m [31m4.5 MB/s[0m eta [36m0:00:00[0m
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m12.6 MB/s[0m eta [36m0:00:00[0m
Collecting xxhash (from datasets)
  Downloading xxhash-3.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (194 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m194.1/194.1 kB[0m [31m12.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting multiprocess (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl (134 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m13.9 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub>=0.21.2 (from datasets)
  Downloading huggingface_hub-0.23.0-py3-none-a

In [None]:
import torch
from torch import nn
from transformers import AutoModelForCausalLM, AutoTokenizer
import random
import torch.nn.functional as F

In [None]:
# Prefer the GPU whenever CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

device(type='cuda')

In [None]:
class LateFusionEnsembler(nn.Module):
  """Late-fusion classifier on top of two causal language models.

  Each model is run on the same token ids, the vocabulary logits of one
  position per sequence are taken from each model, the two logit vectors are
  concatenated, and a single linear layer maps them to 4 output scores.

  Args:
    model1, model2: modules whose call returns an object with a ``.logits``
      tensor of shape (batch, seq_len, modelSize).
    tokenizer: stored on the instance; not used by ``forward``.
    modelSize: size of the last logits dimension of each model.
    device: device the two models and the fusion head are moved to.

  Note: this class does not freeze ``model1``/``model2``; any of their
  parameters that require gradients are exposed through ``parameters()``.
  """
  def __init__(self,model1,model2,tokenizer,modelSize=32000,device="cuda"):
    super().__init__()
    self.model1 = model1.to(device)
    self.model2 = model2.to(device)
    self.tokenizer = tokenizer
    self.modelSize = modelSize
    self.device = device
    # Create the fusion head on the same device as the base models so the
    # module is usable even if the caller never calls .to(device) on it.
    self.linear1 = nn.Linear(self.modelSize*2,4).to(device)

  @staticmethod
  def _last_token_logits(logits, attention_mask=None):
    """Return, per sequence, the logits at the last attended position.

    Without a mask this is simply the final sequence position. With a mask,
    the highest index whose mask value is non-zero is used, which is correct
    for both left- and right-padded batches.
    """
    if attention_mask is None:
      return logits[:, -1, :]
    mask = (attention_mask.to(logits.device) != 0).long()
    positions = torch.arange(logits.shape[1], device=logits.device)
    # position * mask is 0 for padding, so the row-wise max is the index of
    # the last real token (0 if a row is entirely masked).
    last_index = (mask * positions).max(dim=1).values
    rows = torch.arange(logits.shape[0], device=logits.device)
    return logits[rows, last_index, :]

  def forward(self,inputIndices,attention_mask=None):
    """Compute the 4 fused scores for a batch of token ids.

    Args:
      inputIndices: integer tensor of token ids, shape (batch, seq_len).
      attention_mask: optional tensor of the same shape, non-zero for real
        tokens and 0 for padding. When given it is forwarded to both models
        and used to pick the last real token of every sequence; when omitted
        the final sequence position is used for every sequence.

    Returns:
      Tensor of shape (batch, 4).
    """
    if attention_mask is None:
      y1 = self.model1(inputIndices).logits
      y2 = self.model2(inputIndices).logits
    else:
      attention_mask = attention_mask.to(inputIndices.device)
      y1 = self.model1(inputIndices, attention_mask=attention_mask).logits
      y2 = self.model2(inputIndices, attention_mask=attention_mask).logits

    y1 = self._last_token_logits(y1, attention_mask)
    y2 = self._last_token_logits(y2, attention_mask)

    # The base models may emit half-precision logits; match the head's dtype
    # so the linear layer also works outside an autocast region.
    y = torch.cat((y1,y2),dim=1).to(self.linear1.weight.dtype)

    return self.linear1(y)

In [None]:
from datasets import load_dataset, load_from_disk, concatenate_datasets, Dataset,DatasetDict
from peft import (
    LoraConfig,
    prepare_model_for_kbit_training,
    get_peft_model
)
from transformers import (
    AutoModelForSequenceClassification,
    MistralForSequenceClassification,
    PretrainedConfig,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
    Trainer
)

In [None]:
# Mount Google Drive inside the Colab runtime at /content/drive so the dataset
# and model paths used below resolve.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Directory on the mounted Drive that holds the saved "medmcqa-prompts" datasets.
dataset_location = "/content/drive/MyDrive/685 Final Project/Datasets/medmcqa-prompts"

In [None]:
# Load the "micro" train and eval prompt splits from disk (the test split is not loaded).
train_dataset = load_from_disk(f"{dataset_location}/train_prompts_micro.hf")
# test_dataset = load_from_disk(f"{dataset_location}/test_prompts_micro.hf")
eval_dataset = load_from_disk(f"{dataset_location}/eval_prompts_micro.hf")

# Alternative "mini" splits, currently disabled.
# train_dataset = load_from_disk(f"{dataset_location}/train_prompts_mini.hf")
# test_dataset = load_from_disk(f"{dataset_location}/test_prompts_mini.hf")
# eval_dataset = load_from_disk(f"{dataset_location}/eval_prompts_mini.hf")

In [None]:
# Display the training split's summary (feature names and row count).
train_dataset

Dataset({
    features: ['id', 'question', 'opa', 'opb', 'opc', 'opd', 'cop', 'choice_type', 'exp', 'subject_name', 'topic_name', 'prompt', 'label_one_hot'],
    num_rows: 2000
})

In [None]:
# Display the evaluation split's summary (feature names and row count).
eval_dataset

Dataset({
    features: ['id', 'question', 'opa', 'opb', 'opc', 'opd', 'cop', 'choice_type', 'exp', 'subject_name', 'topic_name', 'prompt', 'label_one_hot'],
    num_rows: 500
})

In [None]:
# Load pre-trained models
# Both checkpoints are read from the same Drive directory with identical
# loading options (sequence length, dtype auto-detection, 4-bit quantization).
from unsloth import FastLanguageModel

max_seq_length = 2048 # Choose any! We auto support RoPE Scaling internally!
dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit = True # Use 4bit quantization to reduce memory usage. Can be False.

model_location = "/content/drive/MyDrive/685 Final Project/Models"

# First ensemble member.
model1, tokenizer = FastLanguageModel.from_pretrained(model_location + "/unsloth_domain1",
                                                     max_seq_length=max_seq_length,
                                                     dtype=dtype,
                                                     load_in_4bit=load_in_4bit)

# Second ensemble member.
# NOTE: `tokenizer` is reassigned here, so the tokenizer returned with model2
# replaces the one returned with model1 for everything that follows.
# NOTE(review): presumably both checkpoints share the same vocabulary — confirm.
model2, tokenizer = FastLanguageModel.from_pretrained(model_location + "/ai2_arc_instruction_tuned_mistral_7b",
                                                     max_seq_length=max_seq_length,
                                                     dtype=dtype,
                                                     load_in_4bit=load_in_4bit)

Unused kwargs: ['quant_method']. These kwargs are not used in <class 'transformers.utils.quantization_config.BitsAndBytesConfig'>.


==((====))==  Unsloth: Fast Mistral patching release 2024.5
   \\   /|    GPU: NVIDIA A100-SXM4-40GB. Max memory: 39.564 GB. Platform = Linux.
O^O/ \_/ \    Pytorch: 2.2.1+cu121. CUDA = 8.0. CUDA Toolkit = 12.1.
\        /    Bfloat16 = TRUE. Xformers = 0.0.25.post1. FA = False.
 "-____-"     Free Apache license: http://github.com/unslothai/unsloth


Unsloth 2024.5 patched 32 layers with 32 QKV layers, 32 O layers and 32 MLP layers.
Unused kwargs: ['quant_method']. These kwargs are not used in <class 'transformers.utils.quantization_config.BitsAndBytesConfig'>.


==((====))==  Unsloth: Fast Mistral patching release 2024.5
   \\   /|    GPU: NVIDIA A100-SXM4-40GB. Max memory: 39.564 GB. Platform = Linux.
O^O/ \_/ \    Pytorch: 2.2.1+cu121. CUDA = 8.0. CUDA Toolkit = 12.1.
\        /    Bfloat16 = TRUE. Xformers = 0.0.25.post1. FA = False.
 "-____-"     Free Apache license: http://github.com/unslothai/unsloth


In [None]:
from torch.utils.data import DataLoader, Dataset

class MCQDataset(Dataset):
    """Map-style dataset pairing tokenizer encodings with one-hot labels.

    `encodings` is a mapping from field name to a per-example sequence of
    values; `labels` is a per-example sequence of label vectors. The dataset
    length is the number of labels.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # One tensor per encoding field, built from that field's idx-th entry.
        sample = {}
        for field, values in self.encodings.items():
            sample[field] = torch.tensor(values[idx])
        # Labels are stored as float tensors because they are one-hot vectors.
        sample['labels'] = torch.tensor(self.labels[idx], dtype=torch.float)
        return sample

# Function to encode the data
def encode_data(tokenizer, prompts):
    """Tokenize `prompts` with truncation and padding enabled.

    No explicit max_length is passed, so the tokenizer's own defaults decide
    the truncation/padding length. Returns whatever the tokenizer call returns.
    """
    tokenizer_options = {"truncation": True, "padding": True}
    return tokenizer(prompts, **tokenizer_options)

# ---- Training split ----
# Pull the prompt text and the one-hot label vector out of every training row.
prompts = [item['prompt'] for item in train_dataset]
labels = [item['label_one_hot'] for item in train_dataset]  # one-hot encoded labels

# Tokenize data
encodings = encode_data(tokenizer, prompts)

# Create dataset
train_set = MCQDataset(encodings, labels)

# DataLoader (shuffled so every epoch sees the training rows in a new order)
train_loader = DataLoader(train_set, batch_size=32, shuffle=True)


# ---- Evaluation split ----
# NOTE: `prompts`, `labels` and `encodings` are reused; after this point they
# refer to the evaluation split, not the training split.
prompts = [item['prompt'] for item in eval_dataset]
labels = [item['label_one_hot'] for item in eval_dataset]  # one-hot encoded labels

# Tokenize data
encodings = encode_data(tokenizer, prompts)

# Create dataset
eval_set = MCQDataset(encodings, labels)

# DataLoader: validation is not shuffled. The validation loss is averaged per
# batch and the final batch can be smaller, so a fixed batch composition keeps
# the reported metric reproducible for identical weights.
val_loader = DataLoader(eval_set, batch_size=32, shuffle=False)

In [None]:
from torch.cuda.amp import GradScaler, autocast
from tqdm import tqdm
from torch.nn.functional import softmax
def _forward_batch(model, batch, device, criterion, use_amp, pass_mask):
    """Run one batch through `model` and return (loss, n_correct, n_samples).

    `batch` must contain 'input_ids' and one-hot 'labels'. When `pass_mask` is
    true and the batch has an 'attention_mask', it is forwarded to the model.
    """
    input_ids = batch['input_ids'].to(device)
    labels = batch['labels'].to(device).float()
    with torch.cuda.amp.autocast(enabled=use_amp):
        if pass_mask and 'attention_mask' in batch:
            output = model(input_ids, attention_mask=batch['attention_mask'].to(device)).float()
        else:
            output = model(input_ids).float()
        loss = criterion(output, labels)
    # argmax over raw scores; a softmax first would not change the winner.
    predictions = torch.argmax(output, dim=1)
    labels_indices = torch.argmax(labels, dim=1)
    n_correct = (predictions == labels_indices).sum().item()
    return loss, n_correct, labels.size(0)


def train_and_validate(model, train_loader, val_loader, epochs=3, save_dir=None):
    """Train `model` with Adam + cross-entropy and validate after every epoch.

    Args:
        model: module mapping a batch of token ids to class scores of shape
            (batch, num_classes). If its ``forward`` declares an
            ``attention_mask`` parameter, the batch's mask is passed too.
        train_loader: iterable of dict batches with 'input_ids' and one-hot
            'labels' (and optionally 'attention_mask'); must support ``len``.
        val_loader: same batch format, used for validation.
        epochs: number of passes over `train_loader`.
        save_dir: directory for checkpoints; defaults to the project's Drive
            models folder.

    Side effects:
        Prints progress, per-epoch training/validation loss and accuracy, and
        after each epoch writes the model's state_dict to
        ``<save_dir>/LateFusion<epoch_index>.pth`` (epoch_index starts at 0),
        so earlier epochs are not overwritten by later ones within a call.

    Returns:
        None.
    """
    import inspect  # local import: only needed to probe the model's forward signature

    if save_dir is None:
        saved_model_location = "/content/drive/MyDrive/685 Final Project/Models"
    else:
        saved_model_location = save_dir

    # Mixed precision and loss scaling are only meaningful on CUDA; on a
    # CPU-only machine they are disabled instead of crashing.
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    scaler = GradScaler(enabled=use_cuda)
    model = model.to(device)  # moves parameters/buffers to the device; dtypes are left unchanged
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    criterion = torch.nn.CrossEntropyLoss()

    # Only forward the attention mask to models that explicitly accept it.
    pass_mask = "attention_mask" in inspect.signature(model.forward).parameters

    for epoch in range(epochs):
        total_train_loss = 0
        total_train_correct = 0
        train_samples = 0
        model.train()
        train_pbar = tqdm(train_loader, desc=f"Epoch {epoch+1} [TRAIN]", unit="batch")
        for i, batch in enumerate(train_pbar):
            optimizer.zero_grad()
            loss, batch_correct, batch_size = _forward_batch(
                model, batch, device, criterion, use_cuda, pass_mask)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            loss_value = loss.item()
            # Drop the tensor reference before the next batch is built.
            del loss

            total_train_loss += loss_value
            total_train_correct += batch_correct
            train_samples += batch_size

            train_pbar.set_postfix(loss=loss_value, temp_acc=100 * total_train_correct / train_samples)

            if i % 1000 == 0:
                print(i, loss_value)
                print("Temp accuracy: ", total_train_correct / train_samples * 100)

        # One checkpoint file per epoch, indexed from 0.
        model_save_path = f"{saved_model_location}/LateFusion{epoch}.pth"
        torch.save(model.state_dict(), model_save_path)
        print("model Saved at", model_save_path)

        # max(..., 1) guards against an empty loader.
        avg_train_loss = total_train_loss / max(len(train_loader), 1)
        train_accuracy = total_train_correct / max(train_samples, 1) * 100
        print("Training Accuracy: ", train_accuracy)
        print(f"Epoch {epoch+1}, Loss: {avg_train_loss}")

        model.eval()
        total_val_loss, val_samples, total_val_correct = 0, 0, 0
        with torch.no_grad():
            for batch in val_loader:
                val_loss, batch_correct, batch_size = _forward_batch(
                    model, batch, device, criterion, use_cuda, pass_mask)
                total_val_loss += val_loss.item()
                total_val_correct += batch_correct
                val_samples += batch_size

        avg_val_loss = total_val_loss / max(len(val_loader), 1)
        val_accuracy = total_val_correct / max(val_samples, 1) * 100
        print("Validation Accuracy: ", val_accuracy)
        print(f"Epoch {epoch+1} - Validation Loss: {avg_val_loss:.4f}")



In [None]:
# Ask PyTorch to release cached, currently unused GPU memory before building the ensemble.
torch.cuda.empty_cache()

In [None]:
# Build the late-fusion ensemble from the two loaded models, using the class
# defaults for modelSize (32000) and device ("cuda").
lf = LateFusionEnsembler(model1,model2,tokenizer)

In [None]:
# First training run: one epoch over train_loader, followed by validation on val_loader.
train_and_validate(lf,train_loader,val_loader,epochs=1)

Epoch 1 [TRAIN]:   0%|          | 0/63 [00:00<?, ?batch/s]


Total Correct :  9


Epoch 1 [TRAIN]:   2%|▏         | 1/63 [00:06<06:48,  6.58s/batch, loss=1.52, temp_acc=28.1]

0 1.5157479047775269
Temp accuracy:  28.125

Total Correct :  5


Epoch 1 [TRAIN]:   3%|▎         | 2/63 [00:11<05:41,  5.60s/batch, loss=1.77, temp_acc=21.9]


Total Correct :  12


Epoch 1 [TRAIN]:   5%|▍         | 3/63 [00:16<05:17,  5.29s/batch, loss=1.59, temp_acc=27.1]


Total Correct :  9


Epoch 1 [TRAIN]:   6%|▋         | 4/63 [00:21<05:03,  5.14s/batch, loss=1.59, temp_acc=27.3]


Total Correct :  9


Epoch 1 [TRAIN]:   8%|▊         | 5/63 [00:26<04:53,  5.06s/batch, loss=1.8, temp_acc=27.5]


Total Correct :  6


Epoch 1 [TRAIN]:  10%|▉         | 6/63 [00:31<04:45,  5.01s/batch, loss=1.82, temp_acc=26]


Total Correct :  6


Epoch 1 [TRAIN]:  11%|█         | 7/63 [00:36<04:38,  4.98s/batch, loss=1.8, temp_acc=25]


Total Correct :  11


Epoch 1 [TRAIN]:  13%|█▎        | 8/63 [00:41<04:32,  4.96s/batch, loss=1.51, temp_acc=26.2]


Total Correct :  8


Epoch 1 [TRAIN]:  14%|█▍        | 9/63 [00:45<04:27,  4.95s/batch, loss=1.7, temp_acc=26]


Total Correct :  6


Epoch 1 [TRAIN]:  16%|█▌        | 10/63 [00:50<04:21,  4.94s/batch, loss=1.82, temp_acc=25.3]


Total Correct :  8


Epoch 1 [TRAIN]:  17%|█▋        | 11/63 [00:55<04:16,  4.93s/batch, loss=1.56, temp_acc=25.3]


Total Correct :  9


Epoch 1 [TRAIN]:  19%|█▉        | 12/63 [01:00<04:11,  4.93s/batch, loss=1.57, temp_acc=25.5]


Total Correct :  9


Epoch 1 [TRAIN]:  21%|██        | 13/63 [01:05<04:06,  4.92s/batch, loss=1.65, temp_acc=25.7]


Total Correct :  5


Epoch 1 [TRAIN]:  22%|██▏       | 14/63 [01:10<04:01,  4.92s/batch, loss=1.78, temp_acc=25]


Total Correct :  5


Epoch 1 [TRAIN]:  24%|██▍       | 15/63 [01:15<03:56,  4.92s/batch, loss=1.83, temp_acc=24.4]


Total Correct :  9


Epoch 1 [TRAIN]:  25%|██▌       | 16/63 [01:20<03:51,  4.92s/batch, loss=1.63, temp_acc=24.6]


Total Correct :  7


Epoch 1 [TRAIN]:  27%|██▋       | 17/63 [01:25<03:46,  4.92s/batch, loss=1.65, temp_acc=24.4]


Total Correct :  8


Epoch 1 [TRAIN]:  29%|██▊       | 18/63 [01:30<03:41,  4.92s/batch, loss=1.6, temp_acc=24.5]


Total Correct :  12


Epoch 1 [TRAIN]:  30%|███       | 19/63 [01:35<03:36,  4.92s/batch, loss=1.71, temp_acc=25.2]


Total Correct :  5


Epoch 1 [TRAIN]:  32%|███▏      | 20/63 [01:40<03:31,  4.92s/batch, loss=1.78, temp_acc=24.7]


Total Correct :  9


Epoch 1 [TRAIN]:  33%|███▎      | 21/63 [01:44<03:26,  4.92s/batch, loss=1.76, temp_acc=24.9]


Total Correct :  9


Epoch 1 [TRAIN]:  35%|███▍      | 22/63 [01:49<03:21,  4.92s/batch, loss=1.7, temp_acc=25]


Total Correct :  6


Epoch 1 [TRAIN]:  37%|███▋      | 23/63 [01:54<03:16,  4.92s/batch, loss=1.93, temp_acc=24.7]


Total Correct :  15


Epoch 1 [TRAIN]:  38%|███▊      | 24/63 [01:59<03:11,  4.92s/batch, loss=1.48, temp_acc=25.7]


Total Correct :  7


Epoch 1 [TRAIN]:  40%|███▉      | 25/63 [02:04<03:06,  4.92s/batch, loss=1.79, temp_acc=25.5]


Total Correct :  5


Epoch 1 [TRAIN]:  41%|████▏     | 26/63 [02:09<03:06,  5.05s/batch, loss=2.01, temp_acc=25.1]


Total Correct :  14


Epoch 1 [TRAIN]:  43%|████▎     | 27/63 [02:14<03:00,  5.01s/batch, loss=7.88, temp_acc=25.8]


Total Correct :  8


Epoch 1 [TRAIN]:  44%|████▍     | 28/63 [02:19<02:54,  4.99s/batch, loss=11.2, temp_acc=25.8]


Total Correct :  7


Epoch 1 [TRAIN]:  46%|████▌     | 29/63 [02:24<02:48,  4.97s/batch, loss=20.2, temp_acc=25.6]


Total Correct :  10


Epoch 1 [TRAIN]:  48%|████▊     | 30/63 [02:29<02:43,  4.96s/batch, loss=10.5, temp_acc=25.8]


Total Correct :  8


Epoch 1 [TRAIN]:  49%|████▉     | 31/63 [02:34<02:38,  4.95s/batch, loss=5.58, temp_acc=25.8]


Total Correct :  9


Epoch 1 [TRAIN]:  51%|█████     | 32/63 [02:39<02:33,  4.95s/batch, loss=16, temp_acc=25.9]


Total Correct :  10


Epoch 1 [TRAIN]:  52%|█████▏    | 33/63 [02:44<02:28,  4.94s/batch, loss=9.36, temp_acc=26]


Total Correct :  8


Epoch 1 [TRAIN]:  54%|█████▍    | 34/63 [02:49<02:23,  4.94s/batch, loss=4.84, temp_acc=26]


Total Correct :  4


Epoch 1 [TRAIN]:  56%|█████▌    | 35/63 [02:54<02:18,  4.94s/batch, loss=4.19, temp_acc=25.6]


Total Correct :  4


Epoch 1 [TRAIN]:  57%|█████▋    | 36/63 [02:59<02:13,  4.94s/batch, loss=4.99, temp_acc=25.3]


Total Correct :  7


Epoch 1 [TRAIN]:  59%|█████▊    | 37/63 [03:04<02:08,  4.94s/batch, loss=2.74, temp_acc=25.2]


Total Correct :  11


Epoch 1 [TRAIN]:  60%|██████    | 38/63 [03:09<02:03,  4.94s/batch, loss=2.85, temp_acc=25.4]


Total Correct :  8


Epoch 1 [TRAIN]:  62%|██████▏   | 39/63 [03:14<01:58,  4.94s/batch, loss=2.94, temp_acc=25.4]


Total Correct :  11


Epoch 1 [TRAIN]:  63%|██████▎   | 40/63 [03:19<01:53,  4.94s/batch, loss=2.32, temp_acc=25.6]


Total Correct :  8


Epoch 1 [TRAIN]:  65%|██████▌   | 41/63 [03:23<01:48,  4.94s/batch, loss=3.2, temp_acc=25.6]


Total Correct :  8


Epoch 1 [TRAIN]:  67%|██████▋   | 42/63 [03:28<01:43,  4.94s/batch, loss=3.06, temp_acc=25.6]


Total Correct :  10


Epoch 1 [TRAIN]:  68%|██████▊   | 43/63 [03:33<01:38,  4.93s/batch, loss=2.05, temp_acc=25.7]


Total Correct :  9


Epoch 1 [TRAIN]:  70%|██████▉   | 44/63 [03:38<01:33,  4.93s/batch, loss=1.94, temp_acc=25.8]


Total Correct :  11


Epoch 1 [TRAIN]:  71%|███████▏  | 45/63 [03:43<01:28,  4.93s/batch, loss=1.79, temp_acc=26]


Total Correct :  8


Epoch 1 [TRAIN]:  73%|███████▎  | 46/63 [03:48<01:23,  4.93s/batch, loss=2.26, temp_acc=26]


Total Correct :  10


Epoch 1 [TRAIN]:  75%|███████▍  | 47/63 [03:53<01:18,  4.93s/batch, loss=1.89, temp_acc=26.1]


Total Correct :  8


Epoch 1 [TRAIN]:  76%|███████▌  | 48/63 [03:58<01:14,  4.93s/batch, loss=1.6, temp_acc=26]


Total Correct :  5


Epoch 1 [TRAIN]:  78%|███████▊  | 49/63 [04:03<01:09,  4.93s/batch, loss=1.68, temp_acc=25.8]


Total Correct :  5


Epoch 1 [TRAIN]:  79%|███████▉  | 50/63 [04:08<01:04,  4.93s/batch, loss=1.98, temp_acc=25.6]


Total Correct :  8


Epoch 1 [TRAIN]:  81%|████████  | 51/63 [04:13<00:59,  4.93s/batch, loss=1.64, temp_acc=25.6]


Total Correct :  9


Epoch 1 [TRAIN]:  83%|████████▎ | 52/63 [04:18<00:54,  4.93s/batch, loss=1.43, temp_acc=25.7]


Total Correct :  6


Epoch 1 [TRAIN]:  84%|████████▍ | 53/63 [04:23<00:49,  4.93s/batch, loss=1.52, temp_acc=25.5]


Total Correct :  9


Epoch 1 [TRAIN]:  86%|████████▌ | 54/63 [04:28<00:44,  4.93s/batch, loss=1.47, temp_acc=25.6]


Total Correct :  9


Epoch 1 [TRAIN]:  87%|████████▋ | 55/63 [04:33<00:39,  4.93s/batch, loss=1.71, temp_acc=25.6]


Total Correct :  10


Epoch 1 [TRAIN]:  89%|████████▉ | 56/63 [04:37<00:34,  4.93s/batch, loss=1.46, temp_acc=25.7]


Total Correct :  7


Epoch 1 [TRAIN]:  90%|█████████ | 57/63 [04:42<00:29,  4.93s/batch, loss=1.59, temp_acc=25.7]


Total Correct :  15


Epoch 1 [TRAIN]:  92%|█████████▏| 58/63 [04:47<00:24,  4.93s/batch, loss=1.39, temp_acc=26]


Total Correct :  8


Epoch 1 [TRAIN]:  94%|█████████▎| 59/63 [04:52<00:19,  4.93s/batch, loss=1.8, temp_acc=26]


Total Correct :  5


Epoch 1 [TRAIN]:  95%|█████████▌| 60/63 [04:57<00:14,  4.93s/batch, loss=1.73, temp_acc=25.8]


Total Correct :  8


Epoch 1 [TRAIN]:  97%|█████████▋| 61/63 [05:02<00:09,  4.93s/batch, loss=1.62, temp_acc=25.8]


Total Correct :  8


Epoch 1 [TRAIN]:  98%|█████████▊| 62/63 [05:07<00:04,  4.93s/batch, loss=1.66, temp_acc=25.8]


Total Correct :  3


Epoch 1 [TRAIN]: 100%|██████████| 63/63 [05:10<00:00,  4.93s/batch, loss=2.55, temp_acc=25.8]


model Saved at /content/drive/MyDrive/685 Final Project/Models/LateFusion0.pth
Training Accuracy:  25.75
Epoch 1, Loss: 3.063587164121961
Validation Accuracy:  26.400000000000002
Epoch 1 - Validation Loss: 2.0176


In [None]:
# Restore the ensemble's weights from a checkpoint file on Drive.
# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
model_path = '/content/drive/MyDrive/685 Final Project/Models/LateFusion0.pth'
lf.load_state_dict(torch.load(model_path))

<All keys matched successfully>

In [None]:
# Second training run: two more epochs starting from the reloaded weights.
# train_and_validate builds a new Adam optimizer on every call, so optimizer
# state is not carried over from the previous run.
train_and_validate(lf,train_loader,val_loader,epochs=2)

Epoch 1 [TRAIN]:   2%|▏         | 1/63 [00:04<05:07,  4.95s/batch, loss=2.09, temp_acc=28.1]

0 2.091259002685547
Temp accuracy:  28.125


Epoch 1 [TRAIN]: 100%|██████████| 63/63 [05:08<00:00,  4.90s/batch, loss=1.43, temp_acc=26.3]


model Saved at /content/drive/MyDrive/685 Final Project/Models/LateFusion0.pth
Training Accuracy:  26.3
Epoch 1, Loss: 2.6051125829181974
Validation Accuracy:  32.800000000000004
Epoch 1 - Validation Loss: 1.5131


Epoch 2 [TRAIN]:   2%|▏         | 1/63 [00:04<05:06,  4.94s/batch, loss=1.71, temp_acc=18.8]

0 1.7141356468200684
Temp accuracy:  18.75


Epoch 2 [TRAIN]: 100%|██████████| 63/63 [05:08<00:00,  4.90s/batch, loss=1.72, temp_acc=28.2]


model Saved at /content/drive/MyDrive/685 Final Project/Models/LateFusion1.pth
Training Accuracy:  28.199999999999996
Epoch 2, Loss: 1.670744844845363
Validation Accuracy:  26.6
Epoch 2 - Validation Loss: 1.9734


In [None]:
# Restore the ensemble's weights from a checkpoint file on Drive.
# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
model_path = '/content/drive/MyDrive/685 Final Project/Models/LateFusion1.pth'
lf.load_state_dict(torch.load(model_path))

<All keys matched successfully>

In [None]:
# Third training run: two more epochs from the reloaded weights (the recorded
# output shows this run was interrupted manually during its second epoch).
train_and_validate(lf,train_loader,val_loader,epochs=2)

Epoch 1 [TRAIN]:   2%|▏         | 1/63 [00:06<06:24,  6.21s/batch, loss=1.48, temp_acc=43.8]

0 1.4807219505310059
Temp accuracy:  43.75


Epoch 1 [TRAIN]: 100%|██████████| 63/63 [05:10<00:00,  4.92s/batch, loss=1.39, temp_acc=26.2]


model Saved at /content/drive/MyDrive/685 Final Project/Models/LateFusion.pth
Training Accuracy:  26.25
Epoch 1, Loss: 2.369175250568087
Validation Accuracy:  21.2
Epoch 1 - Validation Loss: 1.3888


Epoch 2 [TRAIN]:   2%|▏         | 1/63 [00:04<05:06,  4.94s/batch, loss=1.47, temp_acc=21.9]

0 1.4694397449493408
Temp accuracy:  21.875


Epoch 2 [TRAIN]:   2%|▏         | 1/63 [00:09<10:00,  9.69s/batch, loss=1.47, temp_acc=21.9]


KeyboardInterrupt: 