In [3]:
import random
import json
import torch
import torch.nn as nn
import torch.optim as optim
import nltk
import numpy as np
from nltk.stem.porter import PorterStemmer

disease_name = 'arthritis'

# Intent definitions: the chat loop reads an 'intents' list whose entries
# carry a "tag" and a list of 'responses'. JSON is UTF-8 by specification,
# so decode it explicitly rather than relying on the platform locale.
with open(f'{disease_name}.json', 'r', encoding='utf-8') as json_data:
    intents = json.load(json_data)

# Training metadata (layer sizes, vocabulary, tag list) saved next to the
# weights. map_location="cpu" lets a checkpoint written from a GPU session
# load on a CPU-only machine; nothing in this script moves work to a GPU.
data = torch.load(f"{disease_name}_data.pth", map_location="cpu")

stemmer = PorterStemmer()

def bow(tokenized, words):
    """Encode a tokenized sentence as a bag-of-words vector over ``words``.

    Each token is lower-cased and Porter-stemmed (via the module-level
    ``stemmer``). Position ``i`` of the result is 1.0 when ``words[i]``
    equals one of those stems, otherwise 0.0.

    Args:
        tokenized: iterable of token strings (e.g. from ``nltk.word_tokenize``).
        words: vocabulary sequence; presumably stemmed/lower-cased the same
            way at training time -- TODO confirm against the training script.

    Returns:
        A float32 ``np.ndarray`` of shape ``(len(words),)``.
    """
    # Build the stem set once: set membership is O(1), whereas the previous
    # list was rescanned for every vocabulary word.
    sentence_stems = {stemmer.stem(word.lower()) for word in tokenized}
    bag = np.zeros(len(words), dtype=np.float32)
    for i, vocab_word in enumerate(words):
        if vocab_word in sentence_stems:
            bag[i] = 1
    return bag

# Pull the saved training metadata needed to rebuild the network
# (layer sizes) and to encode/decode chat traffic (vocabulary, tag list).
input_size, hidden_size, output_size = (
    data[key] for key in ("input_size", "hidden_size", "output_size")
)
all_words, tags = data["all_words"], data["tags"]

class ChatBotModel(nn.Module):
    """Three-layer MLP intent classifier that outputs per-class probabilities.

    Pipeline:
        Linear(input_size -> hidden_size) -> ReLU -> Dropout(0.5)
        -> Linear(hidden_size -> 128) -> ReLU -> Dropout(0.5)
        -> Linear(128 -> output_size) -> Softmax over dim 1.

    The submodule attribute names (``fc1``, ``dropout1``, ``fc2``, ...) are
    the keys of the saved state dict, so they must not be renamed.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        inner_width = 128  # width of the second hidden layer
        drop_p = 0.5       # dropout probability; only active in train mode
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.dropout1 = nn.Dropout(drop_p)
        self.fc2 = nn.Linear(hidden_size, inner_width)
        self.dropout2 = nn.Dropout(drop_p)
        self.fc3 = nn.Linear(inner_width, output_size)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Map a batch of 2-D inputs to class probabilities (softmax over dim 1)."""
        hidden_stages = ((self.fc1, self.dropout1), (self.fc2, self.dropout2))
        for linear, drop in hidden_stages:
            x = drop(linear(x).relu())
        return self.softmax(self.fc3(x))

# Rebuild the architecture with the saved sizes, then load the trained weights.
model = ChatBotModel(input_size, hidden_size, output_size)
# map_location="cpu": the chat loop feeds CPU tensors built from NumPy, so the
# weights must live on the CPU even if the checkpoint was saved from a GPU.
model.load_state_dict(torch.load(f'{disease_name}.pth', map_location="cpu"))
# Inference mode: turns off the dropout layers.
model.eval()

# Interactive chat loop: classify each user message into an intent tag and
# print a random response for that tag. Type "quit" (or send EOF) to exit.
while True:
    try:
        sentence = input("You: ")  # Input
    except EOFError:
        # stdin was closed (e.g. piped input ran out); exit instead of crashing.
        break
    # Tolerate surrounding whitespace and capitalisation in the exit command.
    if sentence.strip().lower() == "quit":
        break

    sentence = nltk.word_tokenize(sentence)
    X = bow(sentence, all_words)
    X = X.reshape(1, X.shape[0])  # add a batch dimension: (1, vocab_size)
    X = torch.from_numpy(X)

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        output = model(X)
    _, predicted = torch.max(output, dim=1)
    predicted_index = predicted.item()

    tag = tags[predicted_index]

    # ChatBotModel.forward already ends in a softmax, so `output` holds class
    # probabilities. Applying torch.softmax a second time squashed them toward
    # a uniform distribution and made `prob` meaningless.
    prob = output[0][predicted_index]
    for intent in intents['intents']:
        if tag == intent["tag"]:
            print(random.choice(intent['responses']))  # Output


The common symptoms of arthritis include pain, stiffness, swelling, and reduced range of motion in the joints.
