# Loading the AUGUST model with trained weights

This cell initializes the AUGUST model and loads pretrained weights from a checkpoint.

Notes and recommendations:
- Replace `"checkpoint_path"` with the actual path to your checkpoint file.
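- When tensors are loaded with `torch.load`, passing `map_location=device` places them on the device chosen earlier, so a checkpoint saved on a GPU can also be loaded on a CPU-only machine.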
- After loading weights, move the model to the device and set it to eval mode for inference:
```python
model.to(device)
model.eval()
```

In [None]:
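from august.models.model import AUGUST
import torch

# Run on the GPU when one is available; the bfloat16 autocast used below targets CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Replace "checkpoint_path" with the path to your trained checkpoint file.
model = AUGUST(pretrained="checkpoint_path")
model.to(device)
model.eval()  # switch layers such as dropout to inference behavior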

# Autonomous Question Answering Sample
First, run the three coarse-level tasks: location, *H. pylori* status, and condition (diagnostic category).
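The three coarse questions are asked as one conversation: each prompt is the running `history` followed by the new question, and every turn is then appended to `history` as `question\nanswer\n\n`. The plain-Python sketch below isolates only that bookkeeping. `run_conversation` and `stub_ask` are hypothetical names, and the stub merely stands in for the `model.get_slide_embedding(...)` and `model(...)` calls; it is not the AUGUST API.

```python
from typing import Callable, List, Tuple


def run_conversation(questions: List[str],
                     ask: Callable[[str, str], str]) -> Tuple[str, List[str]]:
    """Ask questions in order, feeding earlier turns back as context.

    `ask(prompt, history)` is a stand-in for the model call: any function
    mapping the full prompt and the transcript so far to an answer string.
    """
    history = ""
    answers = []
    for q in questions:
        prompt = history + q                # earlier turns come first
        answer = ask(prompt, history)
        answers.append(answer)
        history += f"{q}\n{answer}\n\n"     # same transcript format as the notebook
    return history, answers


# Hypothetical stub model: reports how many earlier turns it was given.
def stub_ask(prompt: str, history: str) -> str:
    return f"seen {history.count(chr(10) + chr(10))} turns"


history, answers = run_conversation(["Q1?", "Q2?", "Q3?"], stub_ask)
print(answers)
print(repr(history))
```

Each later call sees one more completed turn, which is why the order of the coarse questions matters.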

In [None]:
coarse_questions = [
    "Identify both the general location (proximal vs. distal stomach) and the specific anatomical subregion of this slide?\nThe proximal stomach includes the cardia, fundus, and body, while the distal  stomach includes the antrum and prepylorus.\nPlease format your answer as: 'This slide represents the {general location}, specifically the {subregion}.'.",
    "Is this slide positive or negative for Helicobacter pylori infection?\nPlease format your answer as: 'This slide is {positive/negative} for Helicobacter pylori infection.'.",
    "Considering the spectrum of gastric pathology, determine the category of diagnosis represented in this slide: inflammatory disease, benign tumor, dysplasia, or cancer.\nPlease format your answer as: 'The category of diagnosis represented in this slide is {inflammatory disease/benign tumor/dysplasia/cancer}.'."
]
answer_list = []

In [None]:

# Precomputed slide features; map_location lets a tensor saved on a GPU load on a CPU-only machine.
features = torch.load("./august/data/sample_features.pt", map_location=device)
history = ""  # running transcript of earlier question/answer turns
for coarse_q in coarse_questions:
    # Prepend earlier turns so the model sees the conversation so far
    question = history + coarse_q
    with torch.no_grad():
        with torch.autocast('cuda', torch.bfloat16), torch.inference_mode():
            slide_embedding, question_embedding = model.get_slide_embedding(features, question)
            answer = model(slide_embedding, history, question_embedding)
            print("Question:", coarse_q)
            print("Answer:", answer)
        history += f"{coarse_q}\n{answer}\n\n"
        answer_list.append(answer)

## Automatic Diagnosis Based on Condition
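The follow-up questions below depend on the category given in the third coarse answer. Because the coarse prompts request fixed answer formats, the answers can be turned into structured fields with regular expressions. This is a minimal sketch that assumes the model follows those formats exactly; the example strings are hypothetical answers written in the requested format, not real model output, and `parse_coarse` is an illustrative name.

```python
import re
from typing import Dict, List, Optional

# Patterns mirror the answer formats requested in coarse_questions.
# Assumption: the model follows those formats; a mismatch yields None.
LOCATION_RE = re.compile(r"represents the (.+?), specifically the (.+?)\.")
HP_RE = re.compile(r"is (positive|negative) for Helicobacter pylori")
CATEGORY_RE = re.compile(
    r"represented in this slide is (inflammatory disease|benign tumor|dysplasia|cancer)")


def parse_coarse(answers: List[str]) -> Dict[str, Optional[str]]:
    """Turn the three coarse answers (location, H. pylori, condition) into fields."""
    loc = LOCATION_RE.search(answers[0])
    hp = HP_RE.search(answers[1])
    cat = CATEGORY_RE.search(answers[2])
    return {
        "general_location": loc.group(1) if loc else None,
        "subregion": loc.group(2) if loc else None,
        "h_pylori": hp.group(1) if hp else None,
        "category": cat.group(1) if cat else None,
    }


# Hypothetical answers written in the requested formats (not real model output).
example = [
    "This slide represents the distal stomach, specifically the antrum.",
    "This slide is negative for Helicobacter pylori infection.",
    "The category of diagnosis represented in this slide is inflammatory disease.",
]
print(parse_coarse(example))
```

In the notebook this would be called as `parse_coarse(answer_list[:3])`; a `None` field signals an answer that drifted from its template and deserves a manual look.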

In [None]:
im_questions = [
    "Evaluate this slide diagnosed as inflammatory disease for the following histologic features and indicate which are present:\nSydney system: chronic inflammation, neutrophilic activity, glandular atrophy, intestinal metaplasia.\nAdditional features: erosion, ulceration.\nPlease format your answer as:\n* If any Sydney system features are present: 'Among the Sydney system features, this slide shows [chronic inflammation, neutrophilic activity, glandular atrophy, intestinal metaplasia]. Additionally, this slide shows {erosion/ulceration}.'.\n* If no Sydney system features are present: 'This slide shows nothing. Additionally, this slide shows [erosion,ulceration]'.\n* If no additional features (neither erosion nor ulceration) are present, please omit the \u201cAdditionally\u201d sentence.",
    "The following features are present in this slide: chronic inflammation, neutrophilic activity. Specify the inflammation grade for each as mild, moderate, or marked. Only evaluate chronic inflammation and neutrophilic activity if they are included in the list above.\nPlease format your answer as:'Chronic inflammation is {mild/moderate/marked}. Neutrophilic activity is {mild/moderate/marked}.'."
]

bn_question = "This slide is diagnosed with a benign tumor gastric lesion. Identify the specific histologic type observed.\n Please format your answer as: 'This slide shows {fundic gland polyp/hyperplastic polyp/inflammatory polyp/granulation tissue type polyp/xanthoma/gastritis cystica polyposa}.'."

dys_question = "This slide shows dysplastic changes. Determine the grade of dysplasia , indefinite, low grade, or high grade, and if it is low or high grade, specify the type of adenoma present.\nPlease format your answer as:\n* For low or high grade dysplasia: 'The dysplasia is {low grade/high grade}, {tubular adenoma/tubulovillous adenoma}.'\n* For indefinite dysplasia: 'The dysplasia is indefinite.'"

cc_questions = [
    "This slide shows malignant features. Determine the type of malignant tumor: carcinoma, malignant lymphoma, or NOS. If it is a carcinoma, also identify the histologic subtype.\nPlease format your answer as:\n* For malignant lymphoma or NOS: 'The malignant tumor is a {malignant lymphoma/NOS}.'\n* For carcinoma: 'The malignant tumor is a carcinoma, {adenocarcinoma/squamous cell carcinoma/neuroendocrine tumor}.'.",
    "Given the malignant tumor type carcinoma, adenocarcinoma, specify the relevant detail or additional condition.\nPlease format your answer as: 'For adenocarcinoma, the tumor is {well differentiated/moderately differentiated/poorly cohesive carcinoma}.",
    "Given the malignant tumor type malignant lymphoma, specify the relevant detail or additional condition.\nPlease format your answer as: 'For malignant lymphoma, the tumor is {NOS/extranodal marginal zone lymphoma of malt}'."
]

# Check the condition answer (the third coarse answer) to select the follow-up question group
# condition_answer = answer_list[2]
# final_q = None
# if "inflammatory disease" in condition_answer:
#     final_q = im_questions
# elif "benign tumor" in condition_answer:
#     final_q = bn_question
# elif "dysplasia" in condition_answer:
#     final_q = dys_question
# else:
#     final_q = cc_questions

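# Examples for All Conditions

The cells below run every diagnostic branch on the same sample slide for demonstration. In practice only the branch selected by the condition answer applies. Because the follow-up cells append their turns to `history`, re-run the coarse-level cell before trying another branch so that turns from one branch do not leak into the next.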
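The commented-out block at the end of the question-definition cell picks exactly one branch from the condition answer (`answer_list[2]`). Here is the same decision as a small, testable function; `select_branch` is a hypothetical helper that returns the name of the question group rather than the prompts themselves, so it runs without the model.

```python
def select_branch(condition_answer: str) -> str:
    """Map the condition answer to the name of a follow-up question group.

    Mirrors the commented-out selection logic: the first matching category
    wins, and anything else falls through to the cancer questions.
    """
    if "inflammatory disease" in condition_answer:
        return "im_questions"
    if "benign tumor" in condition_answer:
        return "bn_question"
    if "dysplasia" in condition_answer:
        return "dys_question"
    return "cc_questions"


print(select_branch("The category of diagnosis represented in this slide is dysplasia."))
```

Because the last case is a catch-all, an answer matching none of the first three categories (for example, a malformed one) is routed to the cancer questions, exactly as in the commented-out code.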

## Inflammatory Disease
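The grading question should only list the Sydney features reported in the first answer, which the cell below achieves by deleting substrings from the second inflammatory prompt. The sketch pulls that rule into a pure function; `grading_question` is a hypothetical name, and the template in the demo is a shortened stand-in for `im_questions[1]`.

```python
from typing import Optional


def grading_question(sydney_answer: str, template: str) -> Optional[str]:
    """Restrict the grading prompt to the features reported in the first answer.

    Returns None when neither chronic inflammation nor neutrophilic activity
    was reported, i.e. when the grading turn should be skipped.
    """
    has_chronic = "chronic" in sydney_answer
    has_activity = "activity" in sydney_answer
    if not has_chronic and not has_activity:
        return None
    question = template
    if not has_chronic:
        question = question.replace("chronic inflammation, ", "")
    if not has_activity:
        question = question.replace(", neutrophilic activity", "")
    return question


# Shortened stand-in for im_questions[1] (hypothetical, for illustration only).
template = ("The following features are present in this slide: "
            "chronic inflammation, neutrophilic activity. Grade each.")
print(grading_question(
    "Among the Sydney system features, this slide shows [neutrophilic activity].",
    template))
```

Like the notebook, it relies on case-sensitive substring checks (`"chronic"`, `"activity"`), so it assumes the first answer uses the lowercase feature names from the prompt.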

In [None]:
# Work on a copy so re-running this cell does not keep trimming the shared prompt.
im_turns = list(im_questions)
for i in range(len(im_turns)):
    im_q = im_turns[i]
    question = history + im_q
    # Process each turn with the model
    with torch.no_grad():
        with torch.autocast('cuda', torch.bfloat16), torch.inference_mode():
            slide_embedding, question_embedding = model.get_slide_embedding(features, question)
            answer = model(slide_embedding, history, question_embedding)
            print("Question:", im_q)
            print("Answer:", answer)
        history += f"{im_q}\n{answer}\n\n"
        answer_list.append(answer)
    if i == 0:
        # Only grade the Sydney features reported in the first answer.
        if "chronic" not in answer and "activity" not in answer:
            break  # nothing to grade, so skip the grading turn
        if "chronic" not in answer:
            im_turns[1] = im_turns[1].replace("chronic inflammation, ", "")
        if "activity" not in answer:
            im_turns[1] = im_turns[1].replace(", neutrophilic activity", "")

## Benign Tumor or Dysplasia
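Both prompts in this branch ask for a fixed answer format, so the reply can be reduced to a label. This is a minimal sketch under that assumption; `parse_benign`, `parse_dysplasia`, and `BENIGN_TYPES` are illustrative names, with the options copied from `bn_question` and `dys_question`.

```python
import re
from typing import Optional, Sequence, Tuple

# Options listed in bn_question's requested answer format.
BENIGN_TYPES = (
    "fundic gland polyp", "hyperplastic polyp", "inflammatory polyp",
    "granulation tissue type polyp", "xanthoma", "gastritis cystica polyposa",
)

# Mirrors dys_question's two answer formats; the adenoma type is optional.
DYSPLASIA_RE = re.compile(
    r"The dysplasia is (indefinite|low grade|high grade)"
    r"(?:, (tubular adenoma|tubulovillous adenoma))?")


def parse_benign(answer: str, options: Sequence[str] = BENIGN_TYPES) -> Optional[str]:
    """Return the first listed benign type mentioned in the answer, if any."""
    for option in options:
        if option in answer:
            return option
    return None


def parse_dysplasia(answer: str) -> Optional[Tuple[str, Optional[str]]]:
    """Return (grade, adenoma type); the type is None for indefinite dysplasia."""
    m = DYSPLASIA_RE.search(answer)
    if m is None:
        return None
    return m.group(1), m.group(2)


print(parse_benign("This slide shows hyperplastic polyp."))
print(parse_dysplasia("The dysplasia is high grade, tubular adenoma."))
print(parse_dysplasia("The dysplasia is indefinite."))
```

The simple substring search in `parse_benign` is safe here only because no option is a substring of another.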

In [None]:
followup = bn_question  # or dys_question for a dysplasia diagnosis
question = history + followup
# Process the turn with the model
with torch.no_grad():
    with torch.autocast('cuda', torch.bfloat16), torch.inference_mode():
        slide_embedding, question_embedding = model.get_slide_embedding(features, question)
        answer = model(slide_embedding, history, question_embedding)
        print("Question:", followup)
        print("Answer:", answer)
    history += f"{followup}\n{answer}\n\n"
    answer_list.append(answer)

## Cancer
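The cancer branch asks for the tumor type first and asks a second question only for adenocarcinoma (`cc_questions[1]`) or malignant lymphoma (`cc_questions[2]`). The sketch below parses the type answer against its requested format before choosing the follow-up; `parse_cancer_type` and `followup_index` are hypothetical helpers, not part of AUGUST.

```python
import re
from typing import Optional, Tuple

# Mirrors the answer format requested by cc_questions[0].
TYPE_RE = re.compile(
    r"The malignant tumor is a (carcinoma|malignant lymphoma|NOS)"
    r"(?:, (adenocarcinoma|squamous cell carcinoma|neuroendocrine tumor))?")

Parsed = Optional[Tuple[str, Optional[str]]]


def parse_cancer_type(answer: str) -> Parsed:
    """Return (tumor type, carcinoma subtype or None), or None if off-format."""
    m = TYPE_RE.search(answer)
    if m is None:
        return None
    return m.group(1), m.group(2)


def followup_index(parsed: Parsed) -> Optional[int]:
    """Index of the follow-up prompt in cc_questions, or None when there is none."""
    if parsed is None:
        return None
    tumor, subtype = parsed
    if subtype == "adenocarcinoma":
        return 1
    if tumor == "malignant lymphoma":
        return 2
    return None


print(followup_index(parse_cancer_type("The malignant tumor is a carcinoma, adenocarcinoma.")))
```

Matching the whole template, rather than searching for `"adenocarcinoma"` anywhere in the reply as the cell below does, means a reply that drifts from the format returns `None` instead of silently choosing a branch.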

In [None]:

cancer_type_question = cc_questions[0]
question = history + cancer_type_question
with torch.no_grad():
    with torch.autocast('cuda', torch.bfloat16), torch.inference_mode():
        slide_embedding, question_embedding = model.get_slide_embedding(features, question)
        answer = model(slide_embedding, history, question_embedding)
        print("Question:", cancer_type_question)
        print("Answer:", answer)
    history += f"{cancer_type_question}\n{answer}\n\n"
    answer_list.append(answer)

    # Choose the follow-up from the tumor type. Squamous cell carcinoma,
    # neuroendocrine tumor, and NOS have no follow-up question.
    gradings = None
    if "adenocarcinoma" in answer:
        gradings = cc_questions[1]
    elif "malignant lymphoma" in answer:
        gradings = cc_questions[2]

    if gradings is not None:
        question = history + gradings
        with torch.autocast('cuda', torch.bfloat16), torch.inference_mode():
            slide_embedding, question_embedding = model.get_slide_embedding(features, question)
            answer = model(slide_embedding, history, question_embedding)
            print("Question:", gradings)
            print("Answer:", answer)
        history += f"{gradings}\n{answer}\n\n"
        answer_list.append(answer)
