In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Lazily build the full path of every file under the Kaggle input mount, then echo each one.
input_paths = (
    os.path.join(root, fname)
    for root, _subdirs, fnames in os.walk('/kaggle/input')
    for fname in fnames
)
for path in input_paths:
    print(path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All"
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [2]:
import torch
import torch.nn as nn
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
import cv2
import matplotlib.pyplot as plt
import os

# ================ ATTENTION UNET MODEL ================
class conv_block(nn.Module):
    """Two stacked 3x3 Conv -> BatchNorm -> ReLU stages.

    ``padding=1`` keeps the spatial size unchanged; only the channel count goes
    from ``in_c`` to ``out_c``. The six layers sit in ``self.conv`` at indices
    0-5; that index layout is what the saved state_dict keys refer to.
    """

    def __init__(self, in_c, out_c):
        super().__init__()
        stages = []
        # First stage maps in_c -> out_c, second keeps out_c -> out_c.
        for stage_in in (in_c, out_c):
            stages.extend([
                nn.Conv2d(stage_in, out_c, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_c),
                nn.ReLU(inplace=True),
            ])
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.conv(x)

class encoder_block(nn.Module):
    """One U-Net encoder stage: ``conv_block`` followed by 2x2 max-pooling.

    ``forward`` returns ``(skip, pooled)``. ``skip`` is the full-resolution
    feature map kept for the decoder, and ``pooled`` is the half-resolution map
    fed to the next stage.
    """

    def __init__(self, in_c, out_c):
        super().__init__()
        self.conv = conv_block(in_c, out_c)
        # A 2x2 window with the default stride (= kernel size) halves H and W (floored).
        self.pool = nn.MaxPool2d(kernel_size=(2, 2))

    def forward(self, x):
        skip = self.conv(x)
        return skip, self.pool(skip)

class attention_gate(nn.Module):
    """Additive attention gate that re-weights skip features ``s`` using a gating signal ``g``.

    Args:
        in_c: Two-element sequence ``[channels of g, channels of s]``.
        out_c: Channels of the intermediate projection and of the attention map.

    Both inputs are projected to ``out_c`` channels by 1x1 conv + BatchNorm and
    summed, so their projected shapes must be broadcast-compatible (in practice
    the same spatial size). The sigmoid attention map, with values in [0, 1], is
    multiplied elementwise with ``s``. Its ``out_c`` channels must therefore
    broadcast against ``s`` (``out_c == in_c[1]``, as in every decoder stage here).
    """

    @staticmethod
    def _project(channels_in, channels_out):
        # 1x1 conv + BatchNorm: a pure channel projection, spatial size unchanged.
        return nn.Sequential(
            nn.Conv2d(channels_in, channels_out, kernel_size=1, padding=0),
            nn.BatchNorm2d(channels_out),
        )

    def __init__(self, in_c, out_c):
        super().__init__()
        g_channels, s_channels = in_c[0], in_c[1]
        self.Wg = self._project(g_channels, out_c)
        self.Ws = self._project(s_channels, out_c)
        self.relu = nn.ReLU(inplace=True)
        self.output = nn.Sequential(
            nn.Conv2d(out_c, out_c, kernel_size=1, padding=0),
            nn.Sigmoid(),
        )

    def forward(self, g, s):
        combined = self.relu(self.Wg(g) + self.Ws(s))
        attention = self.output(combined)
        return attention * s

class decoder_block(nn.Module):
    """Attention U-Net decoder stage: upsample, gate the skip connection, then fuse.

    Args:
        in_c: Two-element sequence ``[channels of x, channels of skip s]``,
            forwarded to ``attention_gate``.
        out_c: Output channels of the stage. The gated skip is concatenated with
            the upsampled ``x`` and fed to a ``conv_block`` expecting
            ``in_c[0] + out_c`` channels. This assumes ``in_c[1] == out_c``, which
            holds for every stage in ``attention_unet``.
    """

    def __init__(self, in_c, out_c):
        super().__init__()

        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        self.ag = attention_gate(in_c, out_c)
        self.c1 = conv_block(in_c[0] + out_c, out_c)

    def forward(self, x, s):
        """Upsample ``x`` to the resolution of skip ``s``, gate ``s`` with it, and fuse both."""
        x = self.up(x)
        # The encoder's MaxPool2d floors odd sizes, so a plain 2x upsample can differ
        # from the skip by a pixel (input H/W not divisible by 8). Snap to the skip's
        # exact size so the gate's addition and torch.cat below do not fail.
        if x.shape[-2:] != s.shape[-2:]:
            x = nn.functional.interpolate(
                x, size=s.shape[-2:], mode="bilinear", align_corners=True
            )
        s = self.ag(x, s)
        x = torch.cat([x, s], dim=1)
        return self.c1(x)

class attention_unet(nn.Module):
    """Three-level attention U-Net (64-128-256 encoder, 512-channel bottleneck).

    Args:
        in_channels: Channels of the input image. The default of 1 matches the
            grayscale scans used in this notebook.
        out_channels: Channels of the output map (default 1: a single mask).

    The output of the final 1x1 conv is NOT passed through an activation, so it
    holds raw logits; apply ``torch.sigmoid`` to obtain probabilities.
    """

    def __init__(self, in_channels=1, out_channels=1):
        super().__init__()
        # Attribute names (e1..e3, b1, d1..d3, output) are the state_dict key
        # prefixes of saved checkpoints; keep them stable.
        self.e1 = encoder_block(in_channels, 64)
        self.e2 = encoder_block(64, 128)
        self.e3 = encoder_block(128, 256)
        self.b1 = conv_block(256, 512)
        self.d1 = decoder_block([512, 256], 256)
        self.d2 = decoder_block([256, 128], 128)
        self.d3 = decoder_block([128, 64], 64)
        self.output = nn.Conv2d(64, out_channels, kernel_size=1, padding=0)

    def forward(self, x):
        """Map ``x`` of shape (N, in_channels, H, W) to logits of shape (N, out_channels, H, W)."""
        s1, p1 = self.e1(x)
        s2, p2 = self.e2(p1)
        s3, p3 = self.e3(p2)

        b1 = self.b1(p3)
        d1 = self.d1(b1, s3)
        d2 = self.d2(d1, s2)
        d3 = self.d3(d2, s1)
        return self.output(d3)


# ================== STRUCTURED REPORT GENERATOR ==================
class StructuredReportGenerator:
    """Segment a breast-ultrasound scan with an attention U-Net and draft a text report.

    Pipeline (see ``process_scan``):
    1. Load the image as grayscale.
    2. Resize it to 256x256 and predict a binary mask.
    3. Measure the segmented region at the original resolution.
    4. Fill a fixed report template with a rule-based BI-RADS suggestion.

    Args:
        unet_model_path: Path to a saved ``attention_unet`` state_dict. Keys saved
            from an ``nn.DataParallel`` wrapper (prefixed ``module.``) are accepted.
        device: Torch device string, e.g. ``'cpu'`` or ``'cuda'``.
        pixel_to_mm: Physical size of one original-image pixel in millimetres,
            used to convert pixel area to mm². NOTE(review): 0.1 looks like a
            placeholder. Confirm the real calibration for the scanner/dataset.
    """

    # Heuristic cut-offs on mean grayscale intensity (0-255) for echogenicity labels;
    # their derivation is not shown in this notebook.
    HYPOECHOIC_MAX = 100
    ISOECHOIC_MAX = 180
    # Heuristic rule: a region below BOTH limits is reported as BI-RADS 3.
    BENIGN_MAX_ASPECT_RATIO = 1.4
    BENIGN_MAX_AREA_MM2 = 50

    def __init__(self, unet_model_path, device='cpu', pixel_to_mm=0.1):
        self.device = torch.device(device)
        self.unet = self._load_unet(unet_model_path)
        # Produces a (1, 256, 256) float tensor in [0, 1], the input fed to the model.
        self.transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.Grayscale(num_output_channels=1),
            transforms.ToTensor(),
        ])
        self.pixel_to_mm = pixel_to_mm

    def _load_unet(self, model_path):
        """Build an ``attention_unet``, load weights from ``model_path``, and return it in eval mode.

        Raises:
            FileNotFoundError: if ``model_path`` does not exist (from ``torch.load``).
        """
        model = attention_unet()
        # The checkpoint is a plain state_dict, so restrict unpickling to tensors and
        # containers. A tampered .pth file then cannot execute code on load.
        state_dict = torch.load(model_path, map_location=self.device, weights_only=True)

        # Strip the 'module.' prefix that nn.DataParallel adds to every key.
        prefix = 'module.'
        state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }

        model.load_state_dict(state_dict)
        return model.to(self.device).eval()

    def predict_mask(self, image_tensor, threshold=0.5):
        """Return a uint8 mask (values 0 or 255) at the model's input resolution.

        Args:
            image_tensor: One image tensor of shape (C, H, W), e.g. from
                ``self.transform``. The batch dimension is added here.
            threshold: Probability cut-off applied after the sigmoid.
        """
        batch = image_tensor.unsqueeze(0).to(self.device)
        with torch.no_grad():
            logits = self.unet(batch)
        # The network has no final activation, so convert logits to probabilities here.
        probs = torch.sigmoid(logits)[0, 0].cpu().numpy()
        return (probs > threshold).astype(np.uint8) * 255

    def _classify_echogenicity(self, pixels):
        """Map the mean intensity of ``pixels`` to an echogenicity label."""
        if pixels.size == 0:
            return "undetermined"
        mean_intensity = np.mean(pixels)
        if mean_intensity < self.HYPOECHOIC_MAX:
            return "hypoechoic"
        if mean_intensity < self.ISOECHOIC_MAX:
            return "isoechoic"
        return "hyperechoic"

    def extract_features(self, original_img, mask):
        """Measure the segmented region at the original image resolution.

        Args:
            original_img: The original scan (PIL image or array). ``process_scan``
                passes a single-channel grayscale image.
            mask: Mask from ``predict_mask`` (non-zero = lesion), at any resolution.

        Returns:
            An empty dict when the mask contains no region. Otherwise a dict with:
            - ``area_mm``: area of the LARGEST contour, in mm² despite the key name.
            - ``aspect_ratio``: bounding-box width / height of that contour.
            - ``echogenicity``: label from the mean original intensity under ALL
              mask pixels, not only the largest contour.
        """
        orig_img_array = np.array(original_img)
        binary_mask = (mask > 0).astype(np.uint8)

        # cv2.resize takes (width, height); nearest-neighbour keeps the mask binary.
        orig_h, orig_w = orig_img_array.shape[:2]
        resized_mask = cv2.resize(
            binary_mask,
            (orig_w, orig_h),
            interpolation=cv2.INTER_NEAREST
        )

        contours, _ = cv2.findContours(
            resized_mask,
            cv2.RETR_EXTERNAL,
            cv2.CHAIN_APPROX_SIMPLE
        )
        if not contours:
            return {}

        features = {}
        largest_contour = max(contours, key=cv2.contourArea)
        features['area_mm'] = cv2.contourArea(largest_contour) * (self.pixel_to_mm ** 2)

        _, _, w, h = cv2.boundingRect(largest_contour)
        features['aspect_ratio'] = w / max(h, 1e-5)

        pixels = orig_img_array[resized_mask > 0]
        features['echogenicity'] = self._classify_echogenicity(pixels)
        return features

    def generate_structured_report(self, features):
        """Fill the five-section report template from ``extract_features`` output.

        An empty ``features`` dict (nothing segmented) yields a "no lesion" report.
        This avoids describing a 0 mm² mass and assigning it a BI-RADS category.

        NOTE(review): the lesion template hard-codes findings that are NOT measured
        from the image: irregular shape, spiculated margins, heterogeneous
        echotexture and focal hyperechoic components. Verify the wording before any
        clinical use.
        """
        if not features:
            return (
                "1. Clinical Findings: No lesion was segmented at the current threshold.\n"
                "2. Morphological Description: Not applicable.\n"
                "3. Echogenicity Assessment: Not applicable.\n"
                "4. BI-RADS Classification: Not assigned - automated segmentation found no lesion\n"
                "5. Clinical Recommendations: Manual review of the scan by a radiologist recommended."
            )

        probably_benign = (
            features.get('aspect_ratio', 0) < self.BENIGN_MAX_ASPECT_RATIO
            and features.get('area_mm', 0) < self.BENIGN_MAX_AREA_MM2
        )
        if probably_benign:
            birads_category = "Category 3: Probably benign - short-term follow-up suggested"
            recommendations = "Follow-up ultrasound in 6 months recommended."
        else:
            birads_category = "Category 4: Suspicious abnormality - biopsy recommended"
            recommendations = "Ultrasound-guided core needle biopsy for histopathological correlation."

        return (
            "1. Clinical Findings: "
            f"Irregular {features.get('echogenicity', 'hypoechoic')} mass "
            f"measuring {features.get('area_mm', 0):.2f} mm² with spiculated margins.\n"
            "2. Morphological Description: "
            f"Lesion demonstrates irregular contours (aspect ratio: {features.get('aspect_ratio', 0):.2f}) "
            "and heterogeneous internal echotexture.\n"
            "3. Echogenicity Assessment: "
            f"Predominantly {features.get('echogenicity', 'hypoechoic')} "
            "with focal hyperechoic components.\n"
            f"4. BI-RADS Classification: {birads_category}\n"
            f"5. Clinical Recommendations: {recommendations}"
        )

    def process_scan(self, image_path, threshold=0.5):
        """Run the full pipeline on one image file.

        Returns:
            dict with ``original_scan`` (grayscale array at original size),
            ``segmentation_mask`` (0/255 uint8 at 256x256) and ``structured_report`` (str).
        """
        original_img = Image.open(image_path).convert('L')
        img_tensor = self.transform(original_img)
        mask = self.predict_mask(img_tensor, threshold)
        features = self.extract_features(original_img, mask)
        report = self.generate_structured_report(features)

        return {
            "original_scan": np.array(original_img),
            "segmentation_mask": mask,
            "structured_report": report
        }

    def display_results(self, image_path, threshold=0.5):
        """Run ``process_scan``, plot the scan beside its mask, and return the results dict."""
        results = self.process_scan(image_path, threshold)

        fig, axes = plt.subplots(1, 2, figsize=(12, 6))
        panels = [
            (results['original_scan'], 'Original Scan'),
            (results['segmentation_mask'], 'Segmentation Mask'),
        ]
        for ax, (image, title) in zip(axes, panels):
            ax.imshow(image, cmap='gray')
            ax.set_title(title)
            ax.axis('off')

        fig.tight_layout()
        plt.show()

        return results

# ================ USAGE EXAMPLE ================
if __name__ == "__main__":
    # Kaggle input paths. The datasets providing them must be attached to the notebook.
    MODEL_PATH = '/kaggle/input/aunet/pytorch/default/1/attention_unet_busi.pth'
    SCAN_PATH = '/kaggle/input/busi-malignant/malignant/malignant (108).png'
    # NOTE(review): placeholder pixel spacing; confirm against the dataset's calibration.
    PIXEL_TO_MM = 0.1

    # Fail early with an actionable message rather than a bare error from deep inside torch.load.
    for required_path in (MODEL_PATH, SCAN_PATH):
        if not os.path.exists(required_path):
            raise FileNotFoundError(
                f"{required_path} not found - attach the Kaggle dataset/model that "
                "provides it (the first cell lists everything under /kaggle/input)."
            )

    report_generator = StructuredReportGenerator(
        unet_model_path=MODEL_PATH,
        pixel_to_mm=PIXEL_TO_MM,
        device='cuda' if torch.cuda.is_available() else 'cpu'
    )

    results = report_generator.display_results(SCAN_PATH)

    banner = "=" * 40
    print("\n" + banner)
    print("STRUCTURED REPORT".center(40))
    print(banner)
    print(results['structured_report'])


FileNotFoundError: [Errno 2] No such file or directory: '/kaggle/input/aunet/pytorch/default/1/attention_unet_busi.pth'

In [None]:
!pip install sacremoses


Collecting sacremoses
  Downloading sacremoses-0.1.1-py3-none-any.whl.metadata (8.3 kB)
Downloading sacremoses-0.1.1-py3-none-any.whl (897 kB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 897.5/897.5 kB 18.5 MB/s eta 0:00:00
Installing collected packages: sacremoses
Successfully installed sacremoses-0.1.1
