In [2]:
import json
from pathlib import Path

import torch
import yaml
from torch.optim import AdamW
from torch.utils.data import DataLoader

from dataset import HierarchicalInterviewDataset
from model import HierarchicalInterviewScorer
from trainer import Trainer


In [3]:
import yaml

# Load the run configuration, then the train/validation splits it points to.
# Every file is opened with an explicit UTF-8 encoding: without it, open()
# falls back to the platform locale (e.g. cp1252 on Windows), which can
# mis-decode or reject non-ASCII text in the JSON data.
config_path = "config.yaml"
with open(config_path, "r", encoding="utf-8") as file:
    # safe_load only builds plain Python objects (no arbitrary constructors).
    config = yaml.safe_load(file)

# Both split paths are read from the "train" section of the config.
with open(config["train"]["train_data_path"], "r", encoding="utf-8") as f:
    train_data = json.load(f)

with open(config["train"]["val_data_path"], "r", encoding="utf-8") as f:
    val_data = json.load(f)


In [4]:
# Wrap the raw JSON splits in datasets and batch them.
# Both loaders share the batch size configured under config["train"].
batch_size = config["train"]["batch_size"]

# Training split: reshuffled by the loader on every pass.
train_dataset = HierarchicalInterviewDataset(train_data)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)

# Validation split: kept in its stored order.
val_dataset = HierarchicalInterviewDataset(val_data)
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=batch_size)


In [5]:
# Prefer the GPU when one is visible to PyTorch, otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# NOTE(review): the model is not moved to `device` in this notebook;
# presumably Trainer does that with the `device` it receives — confirm.
model = HierarchicalInterviewScorer()

# Optimise every model parameter with the learning rate from the config.
learning_rate = config["train"]["learning_rate"]
optimizer = AdamW(model.parameters(), lr=learning_rate)




In [6]:
# Upper bound on training epochs passed to Trainer.
# In the run captured in this notebook's output, validation loss was lowest
# at epoch 4 (~0.457) and did not improve during epochs 5-10.
NUM_EPOCHS = 10

# Hand the model, optimizer, both loaders and the target device to the
# project Trainer, then run its training loop.
trainer = Trainer(
    model=model,
    optimizer=optimizer,
    train_loader=train_loader,
    val_loader=val_loader,
    device=device,
    max_epochs=NUM_EPOCHS,
)
trainer.train()



100%|██████████| 14/14 [08:37<00:00, 36.98s/it]


Epoch 1 | Train Loss: 13.927306379590716
Validation Loss: 4.6899449825286865


100%|██████████| 14/14 [06:09<00:00, 26.40s/it]


Epoch 2 | Train Loss: 3.850368755204337
Validation Loss: 1.4700711965560913


100%|██████████| 14/14 [06:11<00:00, 26.51s/it]


Epoch 3 | Train Loss: 1.91862781558718
Validation Loss: 0.6817779690027237


100%|██████████| 14/14 [07:02<00:00, 30.17s/it]


Epoch 4 | Train Loss: 1.2395672457558768
Validation Loss: 0.45747625082731247


100%|██████████| 14/14 [10:40<00:00, 45.72s/it]


Epoch 5 | Train Loss: 0.9888021498918533
Validation Loss: 0.4803343415260315


100%|██████████| 14/14 [07:32<00:00, 32.30s/it]


Epoch 6 | Train Loss: 0.8625351105417524
Validation Loss: 0.5173141583800316


100%|██████████| 14/14 [03:51<00:00, 16.53s/it]


Epoch 7 | Train Loss: 0.7000675222703389
Validation Loss: 0.551289401948452


100%|██████████| 14/14 [03:54<00:00, 16.75s/it]


Epoch 8 | Train Loss: 0.7338389115674155
Validation Loss: 0.5217543169856071


100%|██████████| 14/14 [03:48<00:00, 16.29s/it]


Epoch 9 | Train Loss: 0.9143918859107154
Validation Loss: 0.5378134101629257


100%|██████████| 14/14 [09:27<00:00, 40.53s/it]


Epoch 10 | Train Loss: 0.7537154193435397
Validation Loss: 0.532896876335144


In [11]:
# Persist only the learned weights (the state_dict), not the whole module.
checkpoint_file = Path("model/checkpoint.pth")
# torch.save does not create missing parent directories; make sure "model/"
# exists so a finished training run is not lost to a missing-directory error.
checkpoint_file.parent.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), checkpoint_file)

In [12]:
import torch
from model import HierarchicalInterviewScorer

# Rebuild the architecture and restore the weights saved by the training run.
model = HierarchicalInterviewScorer()
checkpoint_path = "model/checkpoint.pth"
# map_location="cpu": the freshly built model lives on the CPU, and this also
#   lets a checkpoint saved from a CUDA run load on a machine without a GPU.
# weights_only=True: restrict unpickling to tensors and plain containers, so
#   a tampered checkpoint file cannot execute arbitrary code while loading.
state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)
# Switch to inference mode (disables dropout); as the last expression of the
# cell, the returned module is also displayed.
model.eval()

  model.load_state_dict(torch.load(checkpoint_path))


HierarchicalInterviewScorer(
  (turn_encoder): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-5): 6 x TransformerBlock(
          (attention): DistilBertSdpaAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)
      

In [20]:
# Count parameters in one pass: record each tensor's element count together
# with whether it takes part in gradient updates.
param_sizes = [(p.numel(), p.requires_grad) for p in model.parameters()]

total_params = sum(size for size, _ in param_sizes)
trainable_params = sum(size for size, needs_grad in param_sizes if needs_grad)

print(f"Total Parameters: {total_params}")
print(f"Trainable Parameters: {trainable_params}")

Total Parameters: 77588483
Trainable Parameters: 77588483
