In [None]:
!pip install transformers torch



In [None]:
%%writefile train.py

import torch
import numpy as np
import json
import torch.nn as nn
from transformers import BertModel, BertTokenizer
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW


# Reproducibility: seed torch's default RNG before the classifier head is
# initialised and before the DataLoader shuffles (both draw from it).
SEED = 42
torch.manual_seed(SEED)

# Sequence length every pattern is padded/truncated to.
MAX_LEN = 20

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# intents.json is read as {"intents": [{"tag": ..., "patterns": [...]}, ...]}.
# Decode explicitly as UTF-8 so non-ASCII patterns don't depend on the
# platform's default encoding.
with open('intents.json', 'r', encoding='utf-8') as f:
    intents = json.load(f)

# Sorted, de-duplicated intent tags; a tag's position is its class index.
tags = sorted({intent['tag'] for intent in intents['intents']})
tag_to_index = {tag: idx for idx, tag in enumerate(tags)}

# Flatten to parallel lists of pattern strings and integer labels.
patterns = []
labels = []
for intent in intents['intents']:
    for pattern in intent['patterns']:
        patterns.append(pattern)
        labels.append(tag_to_index[intent['tag']])

# One batched tokenizer call. With padding="max_length" every row is exactly
# MAX_LEN tokens long, so this equals stacking per-pattern encodings.
encoded = tokenizer(
    patterns,
    add_special_tokens=True,
    max_length=MAX_LEN,
    padding="max_length",
    truncation=True,
    return_tensors="pt"
)

input_ids = encoded['input_ids']
attention_masks = encoded['attention_mask']
labels = torch.tensor(labels)


class ChatDataset(Dataset):
    """Map-style dataset pairing pre-tokenized patterns with their intent labels.

    Args:
        encodings: Indexable collection of token-id rows, one per pattern.
        mask: Indexable collection of attention-mask rows aligned with ``encodings``.
        labels: Indexable collection of class indices aligned with ``encodings``;
            its length defines the dataset length.
    """

    def __init__(self, encodings, mask, labels):
        self.encodings, self.mask, self.labels = encodings, mask, labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # Key names match what the training loop reads from each collated batch.
        token_row = self.encodings[idx]
        mask_row = self.mask[idx]
        label = self.labels[idx]
        return {'input_ids': token_row, 'attention_mask': mask_row, 'labels': label}


# Wrap the encoded tensors in a Dataset and serve shuffled mini-batches of 16.
dataset = ChatDataset(encodings=input_ids, mask=attention_masks, labels=labels)
train_loader = DataLoader(dataset=dataset, batch_size=16, shuffle=True)


class Bert_Arch(nn.Module):
    """Intent classifier: pretrained BERT encoder -> dropout -> linear head.

    Args:
        output_dim: Number of intent classes (length of the logits vector).
        model_name: Checkpoint name passed to ``BertModel.from_pretrained``.
        dropout: Dropout probability applied to the pooled representation.
        freeze_bert: If True, the encoder's parameters get ``requires_grad=False``
            so only the linear head is trained.
    """

    def __init__(self, output_dim, model_name='bert-base-uncased', dropout=0.1,
                 freeze_bert=False):
        super().__init__()
        self.bert = BertModel.from_pretrained(model_name)

        if freeze_bert:
            for param in self.bert.parameters():
                param.requires_grad = False

        self.dropout = nn.Dropout(dropout)
        # Size the head from the encoder config instead of hard-coding 768, so
        # checkpoints with a different hidden size (e.g. bert-large) also work.
        self.fc = nn.Linear(self.bert.config.hidden_size, output_dim)

    def forward(self, sent_id, mask):
        """Return unnormalised class logits.

        Args:
            sent_id: Token-id tensor, assumed (batch, seq_len) as produced by the tokenizer.
            mask: Attention mask with the same shape as ``sent_id``.
        """
        output = self.bert(sent_id, attention_mask=mask)
        # pooler_output: the [CLS] token's final hidden state after BERT's
        # pooling layer.
        cls_vector = output.pooler_output
        return self.fc(self.dropout(cls_vector))

output_dim = len(tags)
model = Bert_Arch(output_dim)
model = model.to(device)

optimizer = AdamW(model.parameters(), lr=2e-5)
cross_entropy = nn.CrossEntropyLoss()

epochs = 50
log_every = 20  # print the mean loss every `log_every` epochs and after the last one

# BertModel.from_pretrained returns the encoder in eval mode (its dropout is
# off). Switch the whole model to training mode before fine-tuning.
model.train()

for epoch in range(epochs):
    total_loss = 0.0

    for batch in train_loader:
        sent_id = batch['input_ids'].to(device)
        mask = batch['attention_mask'].to(device)
        # Separate name so the module-level `labels` tensor is not reassigned.
        batch_labels = batch['labels'].to(device)

        optimizer.zero_grad()

        preds = model(sent_id, mask)

        loss = cross_entropy(preds, batch_labels)
        total_loss += loss.item()

        loss.backward()
        # Clip the global gradient norm to 1.0 to keep fine-tuning updates stable.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

    avg_loss = total_loss / len(train_loader)

    if (epoch + 1) % log_every == 0 or epoch + 1 == epochs:
        print(f"Epoch {epoch+1}/{epochs} | Loss: {avg_loss:.4f}")


# Move the weights to CPU before saving so the checkpoint loads on machines
# without a GPU (no map_location needed). Saving is the script's final step,
# so leaving the model on CPU afterwards is harmless.
model = model.to("cpu")

# BERT's embedding width equals its hidden size; read it from the encoder
# config rather than hard-coding 768.
hidden_size = model.bert.config.hidden_size

output_data = {
    "model_state": model.state_dict(),
    "output_dim": output_dim,
    "tags": tags,
    "vocab_size": len(tokenizer),
    "embed_dim": hidden_size,
    "hidden_size": hidden_size,
    # Must match the max_length used when tokenizing the training patterns.
    "max_len": 20
}

torch.save(output_data, "bert_data.pth")










Overwriting train.py


In [None]:
# Run the training script written above as a separate Python process; it saves
# its checkpoint to bert_data.pth when training finishes.
!python train.py

2026-01-28 08:19:44.727355: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1769588384.831449    2281 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1769588384.864925    2281 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1769588384.971348    2281 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1769588384.971390    2281 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1769588384.971398    2281 computation_placer.cc:177] computation placer alr