In [1]:
!pip install transformers==4.40.0

Defaulting to user installation because normal site-packages is not writeable


In [1]:
import os
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms as T
from transformers import AutoConfig, AutoModel
import pandas as pd

In [2]:
IMAGE_PATH = '/home/jupyter-nafisha/X-ray/Inference_data/Chexpert/patient00314/study8/view1_frontal.jpg'  
CHECKPOINT = "/home/common/checkpoints/stage3_best.pt"

In [3]:
preprocess = T.Compose([
    T.Resize((512, 512), interpolation=T.InterpolationMode.BICUBIC),
    T.ToTensor(),
    T.Normalize(
        mean=[0.48145466, 0.4578275,  0.40821073],
        std =[0.26862954, 0.26130258, 0.27577711]
    ),
])

In [4]:
df = pd.read_csv("/home/jupyter-nafisha/X-ray/Nikita.csv")
label_cols = df.columns.tolist()[1:]    # ["Atelectasis", ..., "Pleural Effusion"]
num_labels = len(label_cols)

In [5]:
class CXRMultiLabel(nn.Module):
    def __init__(self, vision_model, num_labels, pos_weight):
        super().__init__()
        self.vision = vision_model
        in_dim = vision_model.config.hidden_size
        self.head = nn.Sequential(
            nn.Linear(in_dim, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(256, num_labels)
        )
        self.pos_weight = pos_weight

    def forward(self, pixel_values):
        out = self.vision(pixel_values=pixel_values, return_dict=True)
        cls = out.last_hidden_state[:, 0, :]
        logits = self.head(cls)
        return logits

In [6]:
vision_cfg = AutoConfig.from_pretrained(
    "StanfordAIMI/XraySigLIP__vit-l-16-siglip-384__webli",
    trust_remote_code=True
)
vision_full = AutoModel.from_pretrained(
    "StanfordAIMI/XraySigLIP__vit-l-16-siglip-384__webli",
    config=vision_cfg,
    trust_remote_code=True
)
vision_encoder = vision_full.vision_model
del vision_full

pos_weight = torch.ones(num_labels)   # NOT used during inference

model = CXRMultiLabel(
    vision_model=vision_encoder,
    num_labels=num_labels,
    pos_weight=pos_weight
)



In [7]:
ckpt = torch.load(CHECKPOINT, map_location="cpu", weights_only= False)
model.load_state_dict(ckpt["model_state"])
model.eval()

CXRMultiLabel(
  (vision): SiglipVisionTransformer(
    (embeddings): SiglipVisionEmbeddings(
      (patch_embedding): Conv2d(3, 1024, kernel_size=(16, 16), stride=(16, 16), padding=valid)
      (position_embedding): Embedding(1024, 1024)
    )
    (encoder): SiglipEncoder(
      (layers): ModuleList(
        (0-23): 24 x SiglipEncoderLayer(
          (self_attn): SiglipAttention(
            (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
          )
          (layer_norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
          (mlp): SiglipMLP(
            (activation_fn): GELUActivation()
            (fc1): Linear(in_features=1024, out_features=4096, bias=True)
            (fc2): Linear(in_features=4096, out_features=1024, bia

In [8]:
sigmoid = nn.Sigmoid()

In [9]:
def infer_single_image(image_path):
    img = Image.open(image_path).convert("RGB")
    pv = preprocess(img).unsqueeze(0)     # [1,3,512,512]

    with torch.no_grad():
        logits = model(pv)
        probs  = sigmoid(logits).cpu().numpy()[0]

    preds = (probs >= 0.5).astype(int)

    # Return label â†’ (probability, prediction)
    results = { label_cols[i]: (float(probs[i]), int(preds[i])) for i in range(num_labels) }
    return results

In [10]:
pwd

'/home/jupyter-nafisha/X-ray'

In [11]:
import pandas as pd
csv_path= '/home/jupyter-nafisha/X-ray/CSVs/test.csv'
test_data= pd.read_csv(csv_path)

In [16]:
data_dir = '/home/jupyter-nafisha/X-ray/Data'

records = []

count=0

for idx, row in test_data.iterrows():
    file_path = os.path.join(data_dir, row['image_id'])
    # output = {}
    output = infer_single_image(file_path)
    
    diseases = []
    for disease, (prob, pred) in output.items():
        if prob > 0.5:
            diseases.append(disease)
    
    class_name = ", ".join(diseases) if diseases else "No FInding"
    
    records.append({
        "image_id": row["image_id"],
        "class_name": class_name
    })

    count+=1

    if count==5:
        break

# Create the new dataframe
new_df = pd.DataFrame(records)

In [19]:
ls

best_model.pth   [0m[01;34mData[0m/            Nikita.csv      test_predictions.csv
[01;34mcheckpoints[0m/     dataset.py       predict.py      train.py
chexAgent.ipynb  [01;34mInference_data[0m/  [01;34m__pycache__[0m/    transforms.py
CSV.ipynb        last_model.pth   stage3_best.pt  utils.py
[01;34mCSVs[0m/            model.py         Testing.ipynb


In [None]:
new_df.to_csv('/home/jupyter-nafisha/X-ray/CSVs/chexAgent_output.csv')