In [2]:
import torch
import torch.nn as nn
from PIL import Image
import torchvision.transforms as transforms
from transformers import BartModel, BartTokenizer
from torchvision import models
import pandas as pd

# Define the MultiModalClassifier class
class MultiModalClassifier(nn.Module):
    """Late-fusion classifier over a text encoder and an image encoder.

    Each modality is projected into a shared ``hidden_dim`` space by its own
    linear layer. When both modalities are supplied the two projections are
    averaged, and the result is fed to a final linear classifier.

    Args:
        text_model: Callable returning an object with a ``last_hidden_state``
            attribute (e.g. a Hugging Face ``BartModel``). It is mean-pooled
            over dim 1, so presumably ``(batch, seq_len, text_feat_dim)`` —
            TODO confirm for non-BART encoders.
        image_model: Callable mapping an image batch to a feature tensor whose
            last dimension is ``image_feat_dim`` (e.g. a ResNet whose ``fc``
            was replaced by ``nn.Identity``).
        text_feat_dim: Width of the pooled text features.
        image_feat_dim: Width of the image features.
        hidden_dim: Width of the shared fusion space.
        num_classes: Number of output logits.
    """

    def __init__(self, text_model, image_model, text_feat_dim, image_feat_dim, hidden_dim, num_classes):
        super().__init__()
        self.text_model = text_model
        self.image_model = image_model

        # Linear layers to project modality-specific features to a common hidden space
        self.text_fc = nn.Linear(text_feat_dim, hidden_dim)
        self.image_fc = nn.Linear(image_feat_dim, hidden_dim)

        # Final classifier head that outputs logits for each class
        self.classifier = nn.Linear(hidden_dim, num_classes)

    def forward(self, text_input=None, image_input=None):
        """Return class logits computed from whichever modalities are given.

        Args:
            text_input: Mapping of keyword arguments for ``text_model`` (e.g. a
                tokenizer ``BatchEncoding``). A ``"labels"`` entry, if present,
                is dropped before the call.
            image_input: Tensor passed positionally to ``image_model``.

        Returns:
            Output of ``self.classifier`` on the fused features.

        Raises:
            ValueError: If neither ``text_input`` nor ``image_input`` is given
                (previously this surfaced as an obscure error inside the
                classifier when it received ``None``).
        """
        if text_input is None and image_input is None:
            raise ValueError("At least one of text_input or image_input must be provided")

        projected = []

        if text_input is not None:
            # Remove 'labels' if present so that BART doesn't receive an unexpected key
            encoder_kwargs = {k: v for k, v in text_input.items() if k != "labels"}
            text_outputs = self.text_model(**encoder_kwargs)
            # Mean pooling over the sequence dimension (dim=1). No attention mask
            # is applied, so padded positions contribute to the average.
            pooled_text = text_outputs.last_hidden_state.mean(dim=1)
            projected.append(self.text_fc(pooled_text))

        if image_input is not None:
            # The image model should output a feature vector (with final fc replaced by Identity)
            image_features = self.image_model(image_input)
            projected.append(self.image_fc(image_features))

        # If both modalities are provided, average their projected features.
        if len(projected) == 2:
            features = (projected[0] + projected[1]) / 2
        else:
            features = projected[0]

        logits = self.classifier(features)
        return logits

# Load the tokenizer
model_name = "facebook/bart-base"
tokenizer = BartTokenizer.from_pretrained(model_name)

# Image preprocessing for inference: resize to 224x224, convert to a tensor and
# normalize with the standard ImageNet mean/std.
# NOTE(review): this must match the preprocessing used at training time — confirm.
image_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Load the CSV file to get the label list.
# NOTE(review): `unique()` keeps first-appearance order, so class indices are
# only correct if training derived its label list from this same CSV the same
# way — confirm against the training code.
text_df = pd.read_csv("dataset.csv")
label_list = text_df['labels'].unique().tolist()

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_path = "multimodal_model.pth"

# ResNet18 backbone with its fc layer replaced by Identity so it emits features.
# No ImageNet weights are fetched: load_state_dict below is strict by default,
# so every backbone parameter/buffer is overwritten by the checkpoint anyway.
image_model = models.resnet18(weights=None)
image_feat_dim = image_model.fc.in_features  # 512 for resnet18; read before fc is replaced
image_model.fc = nn.Identity()

text_model = BartModel.from_pretrained(model_name)

# Create the MultiModalClassifier; feature widths come from the backbones
# themselves instead of hard-coded constants.
model = MultiModalClassifier(
    text_model=text_model,
    image_model=image_model,
    text_feat_dim=text_model.config.d_model,  # 768 for bart-base
    image_feat_dim=image_feat_dim,
    hidden_dim=512,
    num_classes=len(label_list),
)

# Load the saved state dict.
# SECURITY NOTE: torch.load unpickles the file, which can execute arbitrary
# code; only load checkpoints from a trusted source (newer torch versions also
# accept weights_only=True to restrict this).
model.load_state_dict(torch.load(model_path, map_location=device))
model.to(device)
model.eval()

# Mapping between label names and the class indices used as logit positions
label2id = {label: idx for idx, label in enumerate(label_list)}
id2label = {idx: label for label, idx in label2id.items()}

def inference_text(model, tokenizer, text, device, max_length=128):
    """Classify *text* on its own and return the predicted label name.

    The model is switched to eval mode. The text is tokenized, padded or
    truncated to exactly ``max_length`` tokens, moved to *device*, and scored
    with ``image_input=None``. The arg-max logit index is translated through
    the module-level ``id2label`` mapping.
    """
    model.eval()
    batch = tokenizer(
        text,
        padding="max_length",
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
    # Relocate every encoded tensor onto the target device, in place.
    for name, tensor in list(batch.items()):
        batch[name] = tensor.to(device)

    with torch.no_grad():
        scores = model(text_input=batch, image_input=None)

    best_index = scores.argmax(dim=1).item()
    return id2label[best_index]

def inference_image(model, image_path, transform, device):
    """Perform inference on an image input only and return the label name.

    Args:
        model: Classifier called as ``model(text_input=None, image_input=...)``;
            it is put into eval mode here.
        image_path: Path of the image file; the image is converted to RGB.
        transform: Callable turning a PIL image into a tensor (e.g. the
            module-level ``image_transforms``).
        device: Device the input tensor is moved to.

    Returns:
        The entry of the module-level ``id2label`` mapping for the arg-max logit.
    """
    model.eval()
    # Use a context manager so the file handle is released deterministically;
    # convert() returns an independent, fully loaded image we can keep using.
    with Image.open(image_path) as img:
        rgb_image = img.convert("RGB")
    image = transform(rgb_image).unsqueeze(0)  # add batch dimension
    image = image.to(device)
    with torch.no_grad():
        logits = model(text_input=None, image_input=image)
    pred_id = torch.argmax(logits, dim=1).item()
    return id2label[pred_id]

def inference_both(model, tokenizer, text, image_path, transform, device, max_length=128):
    """Perform inference using both text and image inputs and return the label name.

    Args:
        model: Classifier called as ``model(text_input=..., image_input=...)``;
            it is put into eval mode here.
        tokenizer: Tokenizer called with padding to ``max_length`` and
            ``return_tensors="pt"``.
        text: Text to classify.
        image_path: Path of the image file; the image is converted to RGB.
        transform: Callable turning a PIL image into a tensor.
        device: Device the inputs are moved to.
        max_length: Exact token length after padding/truncation.

    Returns:
        The entry of the module-level ``id2label`` mapping for the arg-max logit.
    """
    model.eval()
    # Process text input
    encoding = tokenizer(
        text,
        padding="max_length",
        truncation=True,
        max_length=max_length,
        return_tensors="pt"
    )
    for key in encoding:
        encoding[key] = encoding[key].to(device)
    # Process image input; the context manager closes the underlying file
    # handle as soon as the converted (fully loaded) copy has been made.
    with Image.open(image_path) as img:
        rgb_image = img.convert("RGB")
    image = transform(rgb_image).unsqueeze(0)  # add batch dimension
    image = image.to(device)
    with torch.no_grad():
        logits = model(text_input=encoding, image_input=image)
    pred_id = torch.argmax(logits, dim=1).item()
    return id2label[pred_id]

# Example inference usage: score the same symptom description and image with
# each modality on its own, then with both modalities fused.
# NOTE(review): the image path below is machine-specific; adjust it before
# running this cell elsewhere.
sample_text = "Fatigue, Fever, Congestion, Body aches, Shortness of breath"
sample_image_path = r"F:\ABDUL\ABDUL 2024\LUNG DISEASE_MULITE_MODEL_AI\images\train\COVID19\COVID19(483).jpg"

# Text only
predicted_label_text = inference_text(
    model=model,
    tokenizer=tokenizer,
    text=sample_text,
    device=device,
)
print(f"Predicted label from text: {predicted_label_text}")

# Image only
predicted_label_image = inference_image(
    model=model,
    image_path=sample_image_path,
    transform=image_transforms,
    device=device,
)
print(f"Predicted label from image: {predicted_label_image}")

# Text and image together
predicted_label_both = inference_both(
    model=model,
    tokenizer=tokenizer,
    text=sample_text,
    image_path=sample_image_path,
    transform=image_transforms,
    device=device,
)
print(f"Predicted label from both modalities: {predicted_label_both}")


AttributeError: partially initialized module 'torch' has no attribute 'types' (most likely due to a circular import)