In [1]:
import torch
from torchvision import transforms
from PIL import Image
import ipywidgets as widgets
from IPython.display import display, clear_output
import io

# Configuration
config = {
    'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    # Checkpoint file passed to torch.load / model.load_state_dict.
    'save_path': './balanced_model_with_others.pth',
    # Minimum subclass softmax probability required to report the subclass.
    'subclass_threshold': 0.7,
    # If the top main-class probability is below this, the image is reported
    # as "Unknown / Others".
    'mainclass_threshold': 0.7,
    # Main-class names, in the order of the main head's output indices.
    'main_classes': ['Clothing', 'Bags', 'Shoes'],
    # Subclass name -> output index of the subclass head. The indices must
    # match the checkpoint; note the shoe entries are NOT listed in index
    # order (Flats=6, Boots=7, High Heels=8).
    'subclass_to_idx': {
        'Dresses': 0,
        'Skirts': 1,
        'Outerwear': 2,
        'Shoulder Bags': 3,
        'Tote Bags': 4,
        'Clutches': 5,
        'High Heels': 8,
        'Boots': 7,
        'Flats': 6
    },
    # subclass_to_main[i] is the main-class index (into 'main_classes') of the
    # subclass whose subclass_to_idx value is i, so entries follow index
    # order, not the name order used above.
    'subclass_to_main': [
        0,  # 0 Dresses → Clothing
        0,  # 1 Skirts → Clothing
        0,  # 2 Outerwear → Clothing
        1,  # 3 Shoulder Bags → Bags
        1,  # 4 Tote Bags → Bags
        1,  # 5 Clutches → Bags
        2,  # 6 Flats → Shoes
        2,  # 7 Boots → Shoes
        2   # 8 High Heels → Shoes
    ]
}

# Image preprocessing pipeline applied to every uploaded image:
# force 3-channel RGB, Resize(256) (an int size scales the shorter edge per
# the torchvision docs), center-crop to 224x224, convert to a tensor, then
# normalize each channel with the mean/std below.
# NOTE(review): these are the widely used ImageNet statistics — presumably
# the same ones used at training time; confirm against the training script.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]


def _force_rgb(img):
    """Return ``img`` converted to RGB so grayscale/RGBA/palette uploads get 3 channels."""
    return img.convert("RGB")


transform = transforms.Compose(
    [
        transforms.Lambda(_force_rgb),
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(_NORM_MEAN, _NORM_STD),
    ]
)

# Load model
class HierarchicalResNet(torch.nn.Module):
    """ResNet-50 backbone with two linear classification heads.

    The backbone's final ``fc`` layer is replaced by ``Identity``, so it
    outputs the features that used to feed ``fc``. Both heads read those same
    features, and one forward pass yields logits for both levels of the label
    hierarchy.

    Args:
        num_main_classes: number of outputs of the main-class head.
        num_sub_classes: number of outputs of the subclass head.
        pretrained: forwarded to ``torch.hub.load``. Defaults to True, which
            downloads the hub's pretrained backbone weights. Pass False when a
            full checkpoint will be loaded with ``load_state_dict`` anyway; this
            skips a download whose weights would be overwritten immediately.
    """

    def __init__(self, num_main_classes, num_sub_classes, pretrained=True):
        super().__init__()
        self.base = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=pretrained)
        num_ftrs = self.base.fc.in_features
        # Strip the backbone's own classifier; the two heads below replace it.
        self.base.fc = torch.nn.Identity()
        self.main_classifier = torch.nn.Linear(num_ftrs, num_main_classes)
        self.sub_classifier = torch.nn.Linear(num_ftrs, num_sub_classes)

    def forward(self, x):
        """Return ``(main_logits, sub_logits)`` for the input batch ``x``.

        NOTE(review): ``x`` is presumably a (batch, 3, 224, 224) normalized
        image tensor, as produced by ``transform``; this is not enforced here.
        """
        features = self.base(x)
        main_output = self.main_classifier(features)
        sub_output = self.sub_classifier(features)
        return main_output, sub_output

# Initialize model
num_main_classes = len(config['main_classes'])
num_sub_classes = len(config['subclass_to_idx'])

# Validate the label configuration before building the model. Subclass
# indices index both the subclass head's outputs and `subclass_to_main`, so
# they must be exactly 0..N-1. Each subclass also needs a valid parent index.
# A mismatch here would otherwise surface later as wrong labels or an
# IndexError inside the classify callback.
if sorted(config['subclass_to_idx'].values()) != list(range(num_sub_classes)):
    raise ValueError("config['subclass_to_idx'] values must be exactly 0..N-1 with no gaps or duplicates")
if len(config['subclass_to_main']) != num_sub_classes:
    raise ValueError(
        f"config['subclass_to_main'] has {len(config['subclass_to_main'])} entries, "
        f"expected {num_sub_classes} (one per subclass)"
    )
if not all(0 <= parent < num_main_classes for parent in config['subclass_to_main']):
    raise ValueError("config['subclass_to_main'] entries must be valid indices into config['main_classes']")

model = HierarchicalResNet(num_main_classes, num_sub_classes).to(config['device'])
# NOTE: torch.load unpickles the checkpoint file; only load checkpoints you trust.
model.load_state_dict(torch.load(config['save_path'], map_location=config['device']))
model.eval()

# Mapping for display: subclass-head output index -> subclass name.
idx_to_subclass = {v: k for k, v in config['subclass_to_idx'].items()}

# Widgets: a single-file image picker (JPEG/PNG only) and an output area
# that the classify callback clears and writes its results into.
uploader = widgets.FileUpload(
    description='Upload Image',
    accept='.jpg,.jpeg,.png',
    multiple=False,
)

out = widgets.Output()

display(uploader, out)

# Classification button callback
def _print_probability_summary(main_probs, sub_probs, top_k=3):
    """Print every main-class probability and the top-k subclass probabilities.

    Args:
        main_probs: 1-D tensor of softmax probabilities, one per entry of
            ``config['main_classes']``.
        sub_probs: 1-D tensor of softmax probabilities over the subclass head.
        top_k: how many subclasses to list. Clamped to the number of
            subclasses so ``torch.topk`` cannot fail on a smaller label set.
    """
    print("\nMain class probabilities:")
    for name, prob in zip(config['main_classes'], main_probs):
        print(f"{name}: {prob.item():.4f}")

    k = min(top_k, sub_probs.numel())
    print(f"\nTop {k} subclass probabilities:")
    topk_values, topk_indices = torch.topk(sub_probs, k)
    for val, idx in zip(topk_values, topk_indices):
        sub_name = idx_to_subclass.get(idx.item(), f"Subclass_{idx.item()}")
        print(f"{sub_name}: {val.item():.4f}")


def on_classify_click(b):
    """Classify the currently uploaded image and print the hierarchical result.

    Registered as the ``on_click`` handler of the "Classify Image" button, so
    ``b`` is the clicked button (unused). All output goes to the ``out``
    widget, which is cleared first.

    Decision logic:
      1. If the top main-class probability is below
         ``config['mainclass_threshold']``, report "Unknown / Others".
      2. Otherwise, report the top subclass only if it belongs to the
         predicted main class (per ``config['subclass_to_main']``) and its
         probability meets ``config['subclass_threshold']``. If not, fall back
         to the main class alone.

    Any error while reading, decoding, or classifying the upload is printed in
    the output area instead of escaping the widget callback.
    """
    with out:
        clear_output()
        if not uploader.value:
            print("Please upload an image first.")
            return

        try:
            # uploader.value is indexed as a sequence of upload dicts (the cell
            # output shows `FileUpload(value=())`). Decode the first file once,
            # inside the try, so corrupt or unsupported files reach the
            # handler below.
            content = uploader.value[0]['content']
            image = Image.open(io.BytesIO(content))
            display(image.resize((224, 224)))

            # Preprocess into a batch of one and run both heads.
            image_tensor = transform(image).unsqueeze(0).to(config['device'])
            with torch.no_grad():
                main_out, sub_out = model(image_tensor)
                main_probs = torch.softmax(main_out, dim=1)[0]
                sub_probs = torch.softmax(sub_out, dim=1)[0]

            main_conf, main_pred = torch.max(main_probs, 0)
            sub_conf, sub_pred = torch.max(sub_probs, 0)
            main_idx = main_pred.item()
            sub_idx = sub_pred.item()
            main_name = config['main_classes'][main_idx]

            print("\n=== Prediction Results ===")

            # Main class too uncertain: report as Unknown / Others.
            if main_conf.item() < config['mainclass_threshold']:
                print(f"Main class confidence {main_conf.item():.4f} is lower than threshold {config['mainclass_threshold']}")
                print("Final classification: Unknown / Others")
                _print_probability_summary(main_probs, sub_probs)
                return

            sub_name = idx_to_subclass.get(sub_idx, "Unknown Subclass")
            belongs_to_main = config['subclass_to_main'][sub_idx] == main_idx
            meets_threshold = sub_conf.item() >= config['subclass_threshold']

            print(f"Main class: {main_name} (Confidence: {main_conf.item():.4f})")
            print(f"Subclass: {sub_name} (Confidence: {sub_conf.item():.4f})")

            if belongs_to_main and meets_threshold:
                print("\nDecision: Use subclass prediction")
                print(f"Final classification: {main_name} -> {sub_name}")
            else:
                print("\nDecision: Fall back to main class prediction")
                reasons = []
                if not belongs_to_main:
                    reasons.append("Subclass does not belong to predicted main class")
                if not meets_threshold:
                    reasons.append(f"Subclass confidence {sub_conf.item():.4f} < threshold {config['subclass_threshold']}")
                print(f"Reason: {', '.join(reasons)}")
                print(f"Final classification: {main_name}")

            _print_probability_summary(main_probs, sub_probs)

        except Exception as e:
            # Deliberately broad: this is a UI callback, so show the problem to
            # the user rather than letting it disappear into the widget log.
            print(f"Error processing image: {e}")

# Button that runs `on_classify_click` against whatever file is currently
# held by the uploader widget.
classify_btn = widgets.Button(
    description="Classify Image",
)
classify_btn.on_click(on_classify_click)

display(classify_btn)


Using cache found in C:\Users\Annie/.cache\torch\hub\pytorch_vision_v0.10.0


FileUpload(value=(), accept='.jpg,.jpeg,.png', description='Upload Image')

Output()

Button(description='Classify Image', style=ButtonStyle())