# NLP Multilabel Classification Model for Herbal Classification Based on Symptoms or Herbal Attributes

In [1]:
import numpy as np
import pandas as pd
from tqdm import tqdm
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModel, AutoTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cpu'

In [3]:
train_path = 'dataset.csv'
train_csv = pd.read_csv(train_path)
train_csv.head()
print(train_csv.shape)

(1360, 12)


In [4]:
train_csv['SYMPTOMS'].str.split().map(lambda x: len(x)).max()

11

In [5]:
maxLenSymp = train_csv['SYMPTOMS'].str.split().map(lambda x: len(x)).max()
extra_tokens = 20

In [6]:
tokenizer = AutoTokenizer.from_pretrained('medicalai/ClinicalBERT')

In [7]:
symptom_tokens = tokenizer(train_csv['SYMPTOMS'].values.tolist(),
                           padding=True,
                           truncation=True,
                           max_length=maxLenSymp + extra_tokens,
                           return_tensors='pt')['input_ids']

In [8]:
class ClassificationDataset(Dataset):
    def __init__(self, csv_path, symptom_tokens, low_limit, high_limit):
        self.data = pd.read_csv(csv_path).iloc[low_limit:high_limit].reset_index(drop=True)
        self.symptom_tokens = symptom_tokens[low_limit:high_limit]

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        return (
            self.symptom_tokens[idx],
            self.data[self.data.columns[2]].iloc[idx],
            self.data[self.data.columns[3]].iloc[idx],
            self.data[self.data.columns[4]].iloc[idx],
            self.data[self.data.columns[5]].iloc[idx],
            self.data[self.data.columns[6]].iloc[idx],
            self.data[self.data.columns[7]].iloc[idx],
            self.data[self.data.columns[8]].iloc[idx],
            self.data[self.data.columns[9]].iloc[idx],
            self.data[self.data.columns[10]].iloc[idx],
            self.data[self.data.columns[11]].iloc[idx]
        )

In [9]:
train_data = ClassificationDataset(train_path, symptom_tokens, 0, int(len(train_csv) * 0.8))
test_data = ClassificationDataset(train_path, symptom_tokens, int(len(train_csv) * 0.8), int(len(train_csv)))

In [10]:
train_dataloader = DataLoader(train_data, shuffle=True, batch_size=32, drop_last=True)
test_dataLoader = DataLoader(test_data, shuffle=True, batch_size=32, drop_last=True)

In [11]:
class Classifier(nn.Module):
    def __init__(self):
        super(Classifier, self).__init__()
        self.embedding = AutoModel.from_pretrained('medicalai/ClinicalBERT')
        self.embedding.config.pad_token_id
        for param in self.embedding.parameters():
            param.requires_grad = False
        
        # self.rnn_symptom = nn.LSTM(input_size=768, hidden_size=64, bidirectional=True, batch_first=True)
        self.rnn_symptom = nn.LSTM(input_size=768, hidden_size=64, bidirectional=True, batch_first=True)
        
        f_in = 128

        self.jackfruit = nn.Linear(in_features=f_in, out_features=1)
        self.sambong = nn.Linear(in_features=f_in, out_features=1)
        self.lemon = nn.Linear(in_features=f_in, out_features=1)
        self.jasmine = nn.Linear(in_features=f_in, out_features=1)
        self.mango = nn.Linear(in_features=f_in, out_features=1)
        self.mint = nn.Linear(in_features=f_in, out_features=1)
        self.ampalaya = nn.Linear(in_features=f_in, out_features=1)
        self.malunggay = nn.Linear(in_features=f_in, out_features=1)
        self.guava = nn.Linear(in_features=f_in, out_features=1)
        self.lagundi = nn.Linear(in_features=f_in, out_features=1)

        self.act = nn.Sigmoid()

    def forward(self, symptom):
        symptom_embedding = self.embedding(symptom).last_hidden_state
        symptom_features = self.rnn_symptom(symptom_embedding)[0][:, -1, :]
        
        
        jackfruit_classify = self.act(self.jackfruit(symptom_features))
        sambong_classify = self.act(self.sambong(symptom_features))
        lemon_classify = self.act(self.lemon(symptom_features))
        jasmine_classify = self.act(self.jasmine(symptom_features))
        mango_classify = self.act(self.mango(symptom_features))
        mint_classify = self.act(self.mint(symptom_features))
        ampalaya_classify = self.act(self.ampalaya(symptom_features))
        malunggay_classify = self.act(self.malunggay(symptom_features))
        guava_classify = self.act(self.guava(symptom_features))
        lagundi_classify = self.act(self.lagundi(symptom_features))

        return (
            jackfruit_classify,
            sambong_classify,
            lemon_classify,
            jasmine_classify,
            mango_classify,
            mint_classify,
            ampalaya_classify,
            malunggay_classify,
            guava_classify,
            lagundi_classify
        )


In [12]:
model = Classifier().to(device)

  return self.fget.__get__(instance, owner)()


In [13]:
loss_fn = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr = 0.001)

In [14]:
epochs = 10

for epoch in range(epochs):
    train_loss, train_loss_jackfruit, train_loss_sambong, train_loss_lemon, train_loss_jasmine, train_loss_mango, train_loss_mint, train_loss_ampalaya, train_loss_malunggay, train_loss_guava, train_loss_lagundi = 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0

    train_acc = 0
    train_acc_jackfruit = 0
    train_acc_sambong = 0
    train_acc_lemon = 0
    train_acc_jasmine = 0
    train_acc_mango = 0
    train_acc_mint = 0
    train_acc_ampalaya = 0
    train_acc_malunggay = 0
    train_acc_guava = 0
    train_acc_lagundi = 0

    test_loss = 0
    test_loss_jackfruit = 0
    test_loss_sambong = 0
    test_loss_lemon = 0
    test_loss_jasmine = 0
    test_loss_mango = 0
    test_loss_mint = 0
    test_loss_ampalaya = 0
    test_loss_malunggay = 0
    test_loss_guava = 0
    test_loss_lagundi = 0

    test_acc = 0
    test_acc_jackfruit = 0
    test_acc_sambong = 0
    test_acc_lemon = 0
    test_acc_jasmine = 0
    test_acc_mango = 0
    test_acc_mint = 0
    test_acc_ampalaya = 0
    test_acc_malunggay = 0
    test_acc_guava = 0
    test_acc_lagundi = 0

    model.train()

    for (index, (symptom, jackfruit_true, sambong_true, lemon_true, jasmine_true, mango_true, mint_true, ampalaya_true, malunggay_true, guava_true, lagundi_true)) in enumerate(train_dataloader):
        optimizer.zero_grad()
        symptom, jackfruit_true, sambong_true, lemon_true, jasmine_true, mango_true, mint_true, ampalaya_true, malunggay_true, guava_true, lagundi_true = symptom.to(device), jackfruit_true.to(device), sambong_true.to(device), lemon_true.to(device), jasmine_true.to(device), mango_true.to(device), mint_true.to(device), ampalaya_true.to(device), malunggay_true.to(device), guava_true.to(device), lagundi_true.to(device)
        
        jackfruit_classify, sambong_classify, lemon_classify, jasmine_classify, mango_classify, mint_classify, ampalaya_classify, malunggay_classify, guava_classify, lagundi_classify = model(symptom)

        loss_jackfruit = loss_fn(jackfruit_classify, jackfruit_true.float().unsqueeze(dim=1))
        loss_sambong = loss_fn(sambong_classify, sambong_true.float().unsqueeze(dim=1))
        loss_lemon = loss_fn(lemon_classify, lemon_true.float().unsqueeze(dim=1))
        loss_jasmine = loss_fn(jasmine_classify, jasmine_true.float().unsqueeze(dim=1))
        loss_mango = loss_fn(mango_classify, mango_true.float().unsqueeze(dim=1))
        loss_mint = loss_fn(mint_classify, mint_true.float().unsqueeze(dim=1))
        loss_ampalaya = loss_fn(ampalaya_classify, ampalaya_true.float().unsqueeze(dim=1))
        loss_malunggay = loss_fn(malunggay_classify, malunggay_true.float().unsqueeze(dim=1))
        loss_guava = loss_fn(guava_classify, guava_true.float().unsqueeze(dim=1))
        loss_lagundi = loss_fn(lagundi_classify, lagundi_true.float().unsqueeze(dim=1))

        loss = loss_jackfruit + loss_sambong + loss_lemon + loss_jasmine + loss_mango + loss_mint + loss_ampalaya + loss_malunggay + loss_guava + loss_lagundi

        train_loss_jackfruit += loss_jackfruit.item()
        train_loss_sambong += loss_sambong.item()
        train_loss_lemon += loss_lemon.item()
        train_loss_jasmine += loss_jasmine.item()
        train_loss_mango += loss_mango.item()
        train_loss_mint += loss_mint.item()
        train_loss_ampalaya += loss_ampalaya.item()
        train_loss_malunggay += loss_malunggay.item()
        train_loss_guava += loss_guava.item()
        train_loss_lagundi += loss_lagundi.item()

        loss.backward()

        optimizer.step()

        acc_jackfruit = (jackfruit_classify.round() == jackfruit_true.float().unsqueeze(dim=1)).sum().item() / jackfruit_classify.size(0)
        acc_sambong = (sambong_classify.round() == sambong_true.float().unsqueeze(dim=1)).sum().item() / sambong_classify.size(0)
        acc_lemon = (lemon_classify.round() == lemon_true.float().unsqueeze(dim=1)).sum().item() / lemon_classify.size(0)
        acc_jasmine = (jasmine_classify.round() == jasmine_true.float().unsqueeze(dim=1)).sum().item() / jasmine_classify.size(0)
        acc_mango = (mango_classify.round() == mango_true.float().unsqueeze(dim=1)).sum().item() / mango_classify.size(0)
        acc_mint = (mint_classify.round() == mint_true.float().unsqueeze(dim=1)).sum().item() / mint_classify.size(0)
        acc_ampalaya = (ampalaya_classify.round() == ampalaya_true.float().unsqueeze(dim=1)).sum().item() / ampalaya_classify.size(0)
        acc_malunggay = (malunggay_classify.round() == malunggay_true.float().unsqueeze(dim=1)).sum().item() / malunggay_classify.size(0)
        acc_guava = (guava_classify.round() == guava_true.float().unsqueeze(dim=1)).sum().item() / guava_classify.size(0)
        acc_lagundi = (lagundi_classify.round() == lagundi_true.float().unsqueeze(dim=1)).sum().item() / lagundi_classify.size(0)

        acc = (acc_jackfruit + acc_sambong + acc_lemon + acc_jasmine + acc_mango + acc_mint + acc_ampalaya + acc_malunggay + acc_guava + acc_lagundi) / 10

        train_acc += acc
        train_acc_jackfruit += acc_jackfruit
        train_acc_sambong += acc_sambong
        train_acc_lemon += acc_lemon
        train_acc_jasmine += acc_jasmine
        train_acc_mango += acc_mango
        train_acc_mint += acc_mint
        train_acc_ampalaya += acc_ampalaya
        train_acc_malunggay += acc_malunggay
        train_acc_guava += acc_guava
        train_acc_lagundi += acc_lagundi

        print("Epoch {}: Batch: {}/{} || Loss Jackfruit: {:.4f} || Loss Sambong: {:.4f} || Loss Lemon: {:.4f} || Loss Jasmine: {:.4f} || Loss Mango: {:.4f} || Loss Mint: {:.4f} || Loss Ampalaya: {:.4f} || Loss Malunggay: {:.4f} || Loss Guava: {:.4f} || Loss Lagundi: {:.4f} || Loss Total: {:.4f} || Acc Jackfruit: {:.4f} || Acc Sambong: {:.4f} || Acc Lemon: {:.4f} || Acc Jasmine: {:.4f} || Acc Mango: {:.4f} || Acc Mint: {:.4f} || Acc Ampalaya: {:.4f} || Acc Malunggay: {:.4f} || Acc Guava: {:.4f} || Acc Lagundi: {:.4f} || Acc Total: {:.4f}".format(
            epoch+1,
            index,
            len(train_dataloader),
            loss_jackfruit,
            loss_sambong,
            loss_lemon,
            loss_jasmine,
            loss_mango,
            loss_mint,
            loss_ampalaya,
            loss_malunggay,
            loss_guava,
            loss_lagundi,
            loss,
            acc_jackfruit,
            acc_sambong,
            acc_lemon,
            acc_jasmine,
            acc_mango,
            acc_mint,
            acc_ampalaya,
            acc_malunggay,
            acc_guava,
            acc_lagundi,
            acc
        ), end='\r')

    print("Epoch Train {}: Loss Jackfruit: {:.4f} || Loss Sambong: {:.4f} || Loss Lemon: {:.4f} || Loss Jasmine: {:.4f} || Loss Mango: {:.4f} || Loss Mint: {:.4f} || Loss Ampalaya: {:.4f} || Loss Malunggay: {:.4f} || Loss Guava: {:.4f} || Loss Lagundi: {:.4f} || Loss Total: {:.4f} || Acc Jackfruit: {:.4f} || Acc Sambong: {:.4f} || Acc Lemon: {:.4f} || Acc Jasmine: {:.4f} || Acc Mango: {:.4f} || Acc Mint: {:.4f} || Acc Ampalaya: {:.4f} || Acc Malunggay: {:.4f} || Acc Guava: {:.4f} || Acc Lagundi: {:.4f} || Acc Total: {:.4f}".format(
        epoch+1,
        train_loss_jackfruit / len(train_dataloader),
        train_loss_sambong / len(train_dataloader),
        train_loss_lemon / len(train_dataloader),
        train_loss_jasmine / len(train_dataloader),
        train_loss_mango / len(train_dataloader),
        train_loss_mint / len(train_dataloader),
        train_loss_ampalaya / len(train_dataloader),
        train_loss_malunggay / len(train_dataloader),
        train_loss_guava / len(train_dataloader),
        train_loss_lagundi / len(train_dataloader),
        train_loss / len(train_dataloader),
        train_acc_jackfruit / len(train_dataloader),
        train_acc_sambong / len(train_dataloader),
        train_acc_lemon / len(train_dataloader),
        train_acc_jasmine / len(train_dataloader),
        train_acc_mango / len(train_dataloader),
        train_acc_mint / len(train_dataloader),
        train_acc_ampalaya / len(train_dataloader),
        train_acc_malunggay / len(train_dataloader),
        train_acc_guava / len(train_dataloader),
        train_acc_lagundi / len(train_dataloader),
        train_acc / len(train_dataloader)
    ))

    model.eval()

    for symptom, jackfruit_true, sambong_true, lemon_true, jasmine_true, mango_true, mint_true, ampalaya_true, malunggay_true, guava_true, lagundi_true in tqdm(test_dataLoader):
        with torch.no_grad():
            symptom, jackfruit_true, sambong_true, lemon_true, jasmine_true, mango_true, mint_true, ampalaya_true, malunggay_true, guava_true, lagundi_true = symptom.to(device), jackfruit_true.to(device), sambong_true.to(device), lemon_true.to(device), jasmine_true.to(device), mango_true.to(device), mint_true.to(device), ampalaya_true.to(device), malunggay_true.to(device), guava_true.to(device), lagundi_true.to(device)

            jackfruit_classify, sambong_classify, lemon_classify, jasmine_classify, mango_classify, mint_classify, ampalaya_classify, malunggay_classify, guava_classify, lagundi_classify = model(symptom)

            loss_jackfruit = loss_fn(jackfruit_classify, jackfruit_true.float().unsqueeze(dim=1))
            loss_sambong = loss_fn(sambong_classify, sambong_true.float().unsqueeze(dim=1))
            loss_lemon = loss_fn(lemon_classify, lemon_true.float().unsqueeze(dim=1))
            loss_jasmine = loss_fn(jasmine_classify, jasmine_true.float().unsqueeze(dim=1))
            loss_mango = loss_fn(mango_classify, mango_true.float().unsqueeze(dim=1))
            loss_mint = loss_fn(mint_classify, mint_true.float().unsqueeze(dim=1))
            loss_ampalaya = loss_fn(ampalaya_classify, ampalaya_true.float().unsqueeze(dim=1))
            loss_malunggay = loss_fn(malunggay_classify, malunggay_true.float().unsqueeze(dim=1))
            loss_guava = loss_fn(guava_classify, guava_true.float().unsqueeze(dim=1))
            loss_lagundi = loss_fn(lagundi_classify, lagundi_true.float().unsqueeze(dim=1))

            loss = loss_jackfruit + loss_sambong + loss_lemon + loss_jasmine + loss_mango + loss_mint + loss_ampalaya + loss_malunggay + loss_guava + loss_lagundi
        
        test_loss_jackfruit += loss_jackfruit.item()
        test_loss_sambong += loss_sambong.item()
        test_loss_lemon += loss_lemon.item()
        test_loss_jasmine += loss_jasmine.item()
        test_loss_mango += loss_mango.item()
        test_loss_mint += loss_mint.item()
        test_loss_ampalaya += loss_ampalaya.item()
        test_loss_malunggay += loss_malunggay.item()
        test_loss_guava += loss_guava.item()
        test_loss_lagundi += loss_lagundi.item()

        acc_jackfruit = (jackfruit_classify.round() == jackfruit_true.float().unsqueeze(dim=1)).sum().item() / jackfruit_classify.size(0)
        acc_sambong = (sambong_classify.round() == sambong_true.float().unsqueeze(dim=1)).sum().item() / sambong_classify.size(0)
        acc_lemon = (lemon_classify.round() == lemon_true.float().unsqueeze(dim=1)).sum().item() / lemon_classify.size(0)
        acc_jasmine = (jasmine_classify.round() == jasmine_true.float().unsqueeze(dim=1)).sum().item() / jasmine_classify.size(0)
        acc_mango = (mango_classify.round() == mango_true.float().unsqueeze(dim=1)).sum().item() / mango_classify.size(0)
        acc_mint = (mint_classify.round() == mint_true.float().unsqueeze(dim=1)).sum().item() / mint_classify.size(0)
        acc_ampalaya = (ampalaya_classify.round() == ampalaya_true.float().unsqueeze(dim=1)).sum().item() / ampalaya_classify.size(0)
        acc_malunggay = (malunggay_classify.round() == malunggay_true.float().unsqueeze(dim=1)).sum().item() / malunggay_classify.size(0)
        acc_guava = (guava_classify.round() == guava_true.float().unsqueeze(dim=1)).sum().item() / guava_classify.size(0)
        acc_lagundi = (lagundi_classify.round() == lagundi_true.float().unsqueeze(dim=1)).sum().item() / lagundi_classify.size(0)

        acc = (acc_jackfruit + acc_sambong + acc_lemon + acc_jasmine + acc_mango + acc_mint + acc_ampalaya + acc_malunggay + acc_guava + acc_lagundi) / 10

        test_acc += acc
        test_acc_jackfruit += acc_jackfruit
        test_acc_sambong += acc_sambong
        test_acc_lemon += acc_lemon
        test_acc_jasmine += acc_jasmine
        test_acc_mango += acc_mango
        test_acc_mint += acc_mint
        test_acc_ampalaya += acc_ampalaya
        test_acc_malunggay += acc_malunggay
        test_acc_guava += acc_guava
        test_acc_lagundi += acc_lagundi
        
    print("Epoch Test {}: Loss Jackfruit: {:.4f} || Loss Sambong: {:.4f} || Loss Lemon: {:.4f} || Loss Jasmine: {:.4f} || Loss Mango: {:.4f} || Loss Mint: {:.4f} || Loss Ampalaya: {:.4f} || Loss Malunggay: {:.4f} || Loss Guava: {:.4f} || Loss Lagundi: {:.4f} || Loss Total: {:.4f} || Acc Jackfruit: {:.4f} || Acc Sambong: {:.4f} || Acc Lemon: {:.4f} || Acc Jasmine: {:.4f} || Acc Mango: {:.4f} || Acc Mint: {:.4f} || Acc Ampalaya: {:.4f} || Acc Malunggay: {:.4f} || Acc Guava: {:.4f} || Acc Lagundi: {:.4f} || Acc Total: {:.4f}".format(
        epoch+1,
        test_loss_jackfruit / len(test_dataLoader),
        test_loss_sambong / len(test_dataLoader),
        test_loss_lemon / len(test_dataLoader),
        test_loss_jasmine / len(test_dataLoader),
        test_loss_mango / len(test_dataLoader),
        test_loss_mint / len(test_dataLoader),
        test_loss_ampalaya / len(test_dataLoader),
        test_loss_malunggay / len(test_dataLoader),
        test_loss_guava / len(test_dataLoader),
        test_loss_lagundi / len(test_dataLoader),
        test_loss / len(test_dataLoader),
        test_acc_jackfruit / len(test_dataLoader),
        test_acc_sambong / len(test_dataLoader),
        test_acc_lemon / len(test_dataLoader),
        test_acc_jasmine / len(test_dataLoader),
        test_acc_mango / len(test_dataLoader),
        test_acc_mint / len(test_dataLoader),
        test_acc_ampalaya / len(test_dataLoader),
        test_acc_malunggay / len(test_dataLoader),
        test_acc_guava / len(test_dataLoader),
        test_acc_lagundi / len(test_dataLoader),
        test_acc / len(test_dataLoader)
    ))

We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.


Epoch Train 1: Loss Jackfruit: 0.4360 || Loss Sambong: 0.6798 || Loss Lemon: 0.4541 || Loss Jasmine: 0.4610 || Loss Mango: 0.4751 || Loss Mint: 0.4946 || Loss Ampalaya: 0.3794 || Loss Malunggay: 0.3979 || Loss Guava: 0.5399 || Loss Lagundi: 0.4726 || Loss Total: 0.0000 || Acc Jackfruit: 0.8640 || Acc Sambong: 0.5478 || Acc Lemon: 0.8401 || Acc Jasmine: 0.8290 || Acc Mango: 0.8116 || Acc Mint: 0.8070 || Acc Ampalaya: 0.8759 || Acc Malunggay: 0.8851 || Acc Guava: 0.7849 || Acc Lagundi: 0.8336 || Acc Total: 0.8079al: 0.8187


100%|██████████| 8/8 [00:04<00:00,  1.60it/s]


Epoch Test 1: Loss Jackfruit: 0.3700 || Loss Sambong: 0.6662 || Loss Lemon: 0.2935 || Loss Jasmine: 0.3962 || Loss Mango: 0.4564 || Loss Mint: 0.4771 || Loss Ampalaya: 0.2974 || Loss Malunggay: 0.3744 || Loss Guava: 0.4044 || Loss Lagundi: 0.3736 || Loss Total: 0.0000 || Acc Jackfruit: 0.8789 || Acc Sambong: 0.5977 || Acc Lemon: 0.9141 || Acc Jasmine: 0.8633 || Acc Mango: 0.8320 || Acc Mint: 0.8164 || Acc Ampalaya: 0.9141 || Acc Malunggay: 0.8828 || Acc Guava: 0.8711 || Acc Lagundi: 0.8750 || Acc Total: 0.8445
Epoch Train 2: Loss Jackfruit: 0.3741 || Loss Sambong: 0.6439 || Loss Lemon: 0.3784 || Loss Jasmine: 0.4335 || Loss Mango: 0.4217 || Loss Mint: 0.4248 || Loss Ampalaya: 0.2953 || Loss Malunggay: 0.3445 || Loss Guava: 0.4962 || Loss Lagundi: 0.4120 || Loss Total: 0.0000 || Acc Jackfruit: 0.8750 || Acc Sambong: 0.6324 || Acc Lemon: 0.8750 || Acc Jasmine: 0.8419 || Acc Mango: 0.8502 || Acc Mint: 0.8474 || Acc Ampalaya: 0.9127 || Acc Malunggay: 0.8915 || Acc Guava: 0.8024 || Acc Lagu

100%|██████████| 8/8 [00:04<00:00,  1.89it/s]


Epoch Test 2: Loss Jackfruit: 0.3601 || Loss Sambong: 0.6345 || Loss Lemon: 0.3107 || Loss Jasmine: 0.3883 || Loss Mango: 0.4519 || Loss Mint: 0.4714 || Loss Ampalaya: 0.3003 || Loss Malunggay: 0.3702 || Loss Guava: 0.3876 || Loss Lagundi: 0.3529 || Loss Total: 0.0000 || Acc Jackfruit: 0.8828 || Acc Sambong: 0.6641 || Acc Lemon: 0.9062 || Acc Jasmine: 0.8672 || Acc Mango: 0.8320 || Acc Mint: 0.8203 || Acc Ampalaya: 0.9062 || Acc Malunggay: 0.8789 || Acc Guava: 0.8750 || Acc Lagundi: 0.8867 || Acc Total: 0.8520
Epoch Train 3: Loss Jackfruit: 0.3692 || Loss Sambong: 0.6168 || Loss Lemon: 0.3722 || Loss Jasmine: 0.4307 || Loss Mango: 0.4167 || Loss Mint: 0.4177 || Loss Ampalaya: 0.2861 || Loss Malunggay: 0.3318 || Loss Guava: 0.4940 || Loss Lagundi: 0.3981 || Loss Total: 0.0000 || Acc Jackfruit: 0.8750 || Acc Sambong: 0.6756 || Acc Lemon: 0.8750 || Acc Jasmine: 0.8419 || Acc Mango: 0.8502 || Acc Mint: 0.8474 || Acc Ampalaya: 0.9127 || Acc Malunggay: 0.8915 || Acc Guava: 0.8024 || Acc Lagu

100%|██████████| 8/8 [00:04<00:00,  1.65it/s]


Epoch Test 3: Loss Jackfruit: 0.3357 || Loss Sambong: 0.5794 || Loss Lemon: 0.2866 || Loss Jasmine: 0.3699 || Loss Mango: 0.4569 || Loss Mint: 0.4646 || Loss Ampalaya: 0.2925 || Loss Malunggay: 0.3504 || Loss Guava: 0.3690 || Loss Lagundi: 0.3527 || Loss Total: 0.0000 || Acc Jackfruit: 0.8906 || Acc Sambong: 0.7188 || Acc Lemon: 0.9141 || Acc Jasmine: 0.8711 || Acc Mango: 0.8281 || Acc Mint: 0.8086 || Acc Ampalaya: 0.9062 || Acc Malunggay: 0.8789 || Acc Guava: 0.8789 || Acc Lagundi: 0.8750 || Acc Total: 0.8570
Epoch Train 4: Loss Jackfruit: 0.3629 || Loss Sambong: 0.5703 || Loss Lemon: 0.3641 || Loss Jasmine: 0.4218 || Loss Mango: 0.4072 || Loss Mint: 0.3981 || Loss Ampalaya: 0.2746 || Loss Malunggay: 0.3247 || Loss Guava: 0.4899 || Loss Lagundi: 0.3857 || Loss Total: 0.0000 || Acc Jackfruit: 0.8750 || Acc Sambong: 0.6949 || Acc Lemon: 0.8750 || Acc Jasmine: 0.8419 || Acc Mango: 0.8502 || Acc Mint: 0.8474 || Acc Ampalaya: 0.9127 || Acc Malunggay: 0.8915 || Acc Guava: 0.8024 || Acc Lagu

100%|██████████| 8/8 [00:05<00:00,  1.56it/s]


Epoch Test 4: Loss Jackfruit: 0.3479 || Loss Sambong: 0.5600 || Loss Lemon: 0.3036 || Loss Jasmine: 0.3746 || Loss Mango: 0.4160 || Loss Mint: 0.4233 || Loss Ampalaya: 0.2724 || Loss Malunggay: 0.3360 || Loss Guava: 0.3946 || Loss Lagundi: 0.3249 || Loss Total: 0.0000 || Acc Jackfruit: 0.8828 || Acc Sambong: 0.6680 || Acc Lemon: 0.9062 || Acc Jasmine: 0.8672 || Acc Mango: 0.8359 || Acc Mint: 0.8203 || Acc Ampalaya: 0.9062 || Acc Malunggay: 0.8789 || Acc Guava: 0.8750 || Acc Lagundi: 0.8828 || Acc Total: 0.8523
Epoch Train 5: Loss Jackfruit: 0.3567 || Loss Sambong: 0.5333 || Loss Lemon: 0.3522 || Loss Jasmine: 0.4123 || Loss Mango: 0.3889 || Loss Mint: 0.3867 || Loss Ampalaya: 0.2663 || Loss Malunggay: 0.3037 || Loss Guava: 0.4782 || Loss Lagundi: 0.3678 || Loss Total: 0.0000 || Acc Jackfruit: 0.8750 || Acc Sambong: 0.7178 || Acc Lemon: 0.8750 || Acc Jasmine: 0.8419 || Acc Mango: 0.8502 || Acc Mint: 0.8474 || Acc Ampalaya: 0.9127 || Acc Malunggay: 0.8915 || Acc Guava: 0.8024 || Acc Lagu

100%|██████████| 8/8 [00:04<00:00,  1.65it/s]


Epoch Test 5: Loss Jackfruit: 0.3485 || Loss Sambong: 0.4982 || Loss Lemon: 0.2829 || Loss Jasmine: 0.3600 || Loss Mango: 0.3902 || Loss Mint: 0.4292 || Loss Ampalaya: 0.2543 || Loss Malunggay: 0.3177 || Loss Guava: 0.3817 || Loss Lagundi: 0.3232 || Loss Total: 0.0000 || Acc Jackfruit: 0.8828 || Acc Sambong: 0.7617 || Acc Lemon: 0.9141 || Acc Jasmine: 0.8672 || Acc Mango: 0.8438 || Acc Mint: 0.8086 || Acc Ampalaya: 0.9062 || Acc Malunggay: 0.8750 || Acc Guava: 0.8906 || Acc Lagundi: 0.8789 || Acc Total: 0.8629
Epoch Train 6: Loss Jackfruit: 0.3501 || Loss Sambong: 0.4898 || Loss Lemon: 0.3386 || Loss Jasmine: 0.4068 || Loss Mango: 0.3672 || Loss Mint: 0.3673 || Loss Ampalaya: 0.2503 || Loss Malunggay: 0.2940 || Loss Guava: 0.4743 || Loss Lagundi: 0.3477 || Loss Total: 0.0000 || Acc Jackfruit: 0.8750 || Acc Sambong: 0.7638 || Acc Lemon: 0.8750 || Acc Jasmine: 0.8419 || Acc Mango: 0.8502 || Acc Mint: 0.8474 || Acc Ampalaya: 0.9127 || Acc Malunggay: 0.8915 || Acc Guava: 0.8024 || Acc Lagu

100%|██████████| 8/8 [00:04<00:00,  1.75it/s]


Epoch Test 6: Loss Jackfruit: 0.3469 || Loss Sambong: 0.4642 || Loss Lemon: 0.2676 || Loss Jasmine: 0.3569 || Loss Mango: 0.3724 || Loss Mint: 0.3898 || Loss Ampalaya: 0.2415 || Loss Malunggay: 0.2987 || Loss Guava: 0.3561 || Loss Lagundi: 0.2923 || Loss Total: 0.0000 || Acc Jackfruit: 0.8789 || Acc Sambong: 0.7891 || Acc Lemon: 0.9180 || Acc Jasmine: 0.8711 || Acc Mango: 0.8359 || Acc Mint: 0.8125 || Acc Ampalaya: 0.9102 || Acc Malunggay: 0.8828 || Acc Guava: 0.8750 || Acc Lagundi: 0.8789 || Acc Total: 0.8652
Epoch Train 7: Loss Jackfruit: 0.3331 || Loss Sambong: 0.4397 || Loss Lemon: 0.3365 || Loss Jasmine: 0.3950 || Loss Mango: 0.3525 || Loss Mint: 0.3379 || Loss Ampalaya: 0.2426 || Loss Malunggay: 0.2872 || Loss Guava: 0.4634 || Loss Lagundi: 0.3283 || Loss Total: 0.0000 || Acc Jackfruit: 0.8750 || Acc Sambong: 0.7960 || Acc Lemon: 0.8750 || Acc Jasmine: 0.8419 || Acc Mango: 0.8502 || Acc Mint: 0.8493 || Acc Ampalaya: 0.9127 || Acc Malunggay: 0.8915 || Acc Guava: 0.8024 || Acc Lagu

100%|██████████| 8/8 [00:04<00:00,  1.73it/s]


Epoch Test 7: Loss Jackfruit: 0.3345 || Loss Sambong: 0.4272 || Loss Lemon: 0.2789 || Loss Jasmine: 0.3389 || Loss Mango: 0.3493 || Loss Mint: 0.3444 || Loss Ampalaya: 0.2359 || Loss Malunggay: 0.2912 || Loss Guava: 0.3586 || Loss Lagundi: 0.2629 || Loss Total: 0.0000 || Acc Jackfruit: 0.8789 || Acc Sambong: 0.8086 || Acc Lemon: 0.9062 || Acc Jasmine: 0.8672 || Acc Mango: 0.8359 || Acc Mint: 0.8242 || Acc Ampalaya: 0.9062 || Acc Malunggay: 0.8789 || Acc Guava: 0.8789 || Acc Lagundi: 0.9102 || Acc Total: 0.8695
Epoch Train 8: Loss Jackfruit: 0.3287 || Loss Sambong: 0.4085 || Loss Lemon: 0.3260 || Loss Jasmine: 0.3783 || Loss Mango: 0.3420 || Loss Mint: 0.3110 || Loss Ampalaya: 0.2443 || Loss Malunggay: 0.2790 || Loss Guava: 0.4504 || Loss Lagundi: 0.3167 || Loss Total: 0.0000 || Acc Jackfruit: 0.8750 || Acc Sambong: 0.8189 || Acc Lemon: 0.8750 || Acc Jasmine: 0.8419 || Acc Mango: 0.8502 || Acc Mint: 0.8686 || Acc Ampalaya: 0.9127 || Acc Malunggay: 0.8915 || Acc Guava: 0.8024 || Acc Lagu

100%|██████████| 8/8 [00:04<00:00,  1.67it/s]


Epoch Test 8: Loss Jackfruit: 0.3285 || Loss Sambong: 0.3924 || Loss Lemon: 0.2678 || Loss Jasmine: 0.3241 || Loss Mango: 0.3378 || Loss Mint: 0.3269 || Loss Ampalaya: 0.2364 || Loss Malunggay: 0.3039 || Loss Guava: 0.3359 || Loss Lagundi: 0.2445 || Loss Total: 0.0000 || Acc Jackfruit: 0.8750 || Acc Sambong: 0.8438 || Acc Lemon: 0.9141 || Acc Jasmine: 0.8672 || Acc Mango: 0.8359 || Acc Mint: 0.8320 || Acc Ampalaya: 0.9023 || Acc Malunggay: 0.8750 || Acc Guava: 0.8906 || Acc Lagundi: 0.9102 || Acc Total: 0.8746
Epoch Train 9: Loss Jackfruit: 0.3124 || Loss Sambong: 0.3557 || Loss Lemon: 0.3192 || Loss Jasmine: 0.3713 || Loss Mango: 0.3062 || Loss Mint: 0.2789 || Loss Ampalaya: 0.2360 || Loss Malunggay: 0.2778 || Loss Guava: 0.4434 || Loss Lagundi: 0.2842 || Loss Total: 0.0000 || Acc Jackfruit: 0.8768 || Acc Sambong: 0.8511 || Acc Lemon: 0.8750 || Acc Jasmine: 0.8428 || Acc Mango: 0.8511 || Acc Mint: 0.8869 || Acc Ampalaya: 0.9127 || Acc Malunggay: 0.8915 || Acc Guava: 0.8125 || Acc Lagu

100%|██████████| 8/8 [00:04<00:00,  1.80it/s]


Epoch Test 9: Loss Jackfruit: 0.3038 || Loss Sambong: 0.3422 || Loss Lemon: 0.2643 || Loss Jasmine: 0.3855 || Loss Mango: 0.3104 || Loss Mint: 0.3200 || Loss Ampalaya: 0.2074 || Loss Malunggay: 0.2811 || Loss Guava: 0.3605 || Loss Lagundi: 0.2404 || Loss Total: 0.0000 || Acc Jackfruit: 0.8789 || Acc Sambong: 0.8672 || Acc Lemon: 0.9141 || Acc Jasmine: 0.8750 || Acc Mango: 0.8438 || Acc Mint: 0.8359 || Acc Ampalaya: 0.9141 || Acc Malunggay: 0.8789 || Acc Guava: 0.9023 || Acc Lagundi: 0.9219 || Acc Total: 0.8832
Epoch Train 10: Loss Jackfruit: 0.3002 || Loss Sambong: 0.3236 || Loss Lemon: 0.3122 || Loss Jasmine: 0.3456 || Loss Mango: 0.2962 || Loss Mint: 0.2484 || Loss Ampalaya: 0.2216 || Loss Malunggay: 0.2705 || Loss Guava: 0.4093 || Loss Lagundi: 0.2601 || Loss Total: 0.0000 || Acc Jackfruit: 0.8787 || Acc Sambong: 0.8768 || Acc Lemon: 0.8750 || Acc Jasmine: 0.8483 || Acc Mango: 0.8557 || Acc Mint: 0.8952 || Acc Ampalaya: 0.9127 || Acc Malunggay: 0.8925 || Acc Guava: 0.8300 || Acc Lag

100%|██████████| 8/8 [00:04<00:00,  1.67it/s]

Epoch Test 10: Loss Jackfruit: 0.2843 || Loss Sambong: 0.2624 || Loss Lemon: 0.2665 || Loss Jasmine: 0.3000 || Loss Mango: 0.2926 || Loss Mint: 0.2544 || Loss Ampalaya: 0.2257 || Loss Malunggay: 0.2915 || Loss Guava: 0.2993 || Loss Lagundi: 0.1985 || Loss Total: 0.0000 || Acc Jackfruit: 0.8789 || Acc Sambong: 0.9258 || Acc Lemon: 0.9141 || Acc Jasmine: 0.8750 || Acc Mango: 0.8398 || Acc Mint: 0.9023 || Acc Ampalaya: 0.9062 || Acc Malunggay: 0.8711 || Acc Guava: 0.9062 || Acc Lagundi: 0.9297 || Acc Total: 0.8949





In [19]:
#SAVE MODEL USING PICKLE
import pickle
model_pkl_file = "herbal_classifier_based_on_Symptoms"  

with open(model_pkl_file, 'wb') as file:  
    pickle.dump(model, file)

In [18]:
tryInput = "Can you please recommend me a herbal for headaches"
# ans_cols = ["JACKFRUIT", "SAMBONG", "LEMON", "JASMINE", "MANGO", "MINT", "AMPALAYA", "MALUNGGAY", "GUAVA", "LAGUNDI"]

# encoding = tokenizer(tryInput, return_tensors="pt")
# encoding = {k: v.to(model) for k, v in encoding.items()}

# outputs = model(**encoding)
# logits = outputs.logits

# sigmoid = torch.nn.Sigmoid()

# probs = sigmoid(logits.squeeze().cpu())
# predictions = np.zeros(probs.shape)
# predictions[np.where(probs >= 0.5)] = 1

# predicted_labels = [ans_cols[myIndex] for myIndex, label in enumerate(predictions) if label == 1.0]
# print(predicted_labels)

print(model(tryInput))

TypeError: string indices must be integers