# ## 1. Title & Description
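**Lung Nodule Detection EDA**  
This notebook explores the LIDC-IDRI dataset:  
- Class balance of the nodule / no-nodule labels  
- Example CT slices (the central slice of each sampled patient)  
- Pixel-intensity distributions  
- A one-batch sanity check of the dataloader and model  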

# ## 2. Imports

In [None]:
# %%
import os
import glob
import pandas as pd
import numpy as np
import pydicom
import matplotlib.pyplot as plt
import seaborn as sns

from tqdm import tqdm

# set plots inline
%matplotlib inline

# ## 3. Load Annotations & Class Balance

In [None]:
# Location of the annotations table — change it if your file is named differently.
ANNOTATIONS_CSV = '../data/annotations.csv'
ann = pd.read_csv(ANNOTATIONS_CSV)
print(ann.head())

# Tally rows per label, replacing the raw 0/1 codes with readable names.
label_names = {0: 'No Nodule', 1: 'Nodule'}
counts = ann['label'].value_counts()
counts = counts.rename(index=label_names)
print(counts)
counts.plot.pie(autopct='%.1f%%', title='Class Balance', ylabel='')

## 4. Sample CT Slice Visualization

In [None]:
def load_central_slice(patient_id, data_dir='../data/LIDC-IDRI'):
    """Load the middle slice of a patient's DICOM series, min-max scaled to [0, 1].

    The ``*.dcm`` files in ``<data_dir>/<patient_id>`` are ordered by their DICOM
    ``InstanceNumber`` and the file at index ``n // 2`` of that ordering is decoded.

    Parameters
    ----------
    patient_id : str or int
        Name of the patient sub-directory under ``data_dir`` (converted with ``str``).
    data_dir : str
        Root directory holding one sub-directory of ``*.dcm`` files per patient.

    Returns
    -------
    np.ndarray
        The central slice's pixel array as float32, scaled to [0, 1]. A constant
        slice (max == min) is returned as all zeros instead of NaNs.

    Raises
    ------
    FileNotFoundError
        If the patient directory contains no ``*.dcm`` files.
    """
    folder = os.path.join(data_dir, str(patient_id))
    files = glob.glob(os.path.join(folder, '*.dcm'))
    if not files:
        raise FileNotFoundError(f'No .dcm files found in {folder!r}')

    # Only the header is needed to order slices; skipping pixel data avoids
    # decoding the whole volume just to keep a single slice.
    keyed = [(pydicom.dcmread(fp, stop_before_pixels=True).InstanceNumber, fp) for fp in files]
    ordered = [fp for _, fp in sorted(keyed, key=lambda pair: pair[0])]
    central_path = ordered[len(ordered) // 2]
    img = pydicom.dcmread(central_path).pixel_array.astype(np.float32)

    # Min-max normalize; a constant slice would otherwise divide by zero -> NaN.
    lo, hi = float(img.min()), float(img.max())
    if hi > lo:
        return (img - lo) / (hi - lo)
    return np.zeros_like(img)

# Show the central slice for up to N_PER_CLASS distinct positive and negative patients.
# .unique() keeps each patient once even if it appears on several annotation rows
# (NOTE(review): whether annotations.csv has one row per patient is not visible here;
# .unique() is a no-op if it does).
N_PER_CLASS = 4
pos_ids = ann.loc[ann.label == 1, 'patientid'].unique()[:N_PER_CLASS]
neg_ids = ann.loc[ann.label == 0, 'patientid'].unique()[:N_PER_CLASS]

# squeeze=False keeps `axes` 2-D even if N_PER_CLASS is set to 1.
fig, axes = plt.subplots(2, N_PER_CLASS, figsize=(12, 6), squeeze=False)
for row_axes, ids, label in ((axes[0], pos_ids, 1), (axes[1], neg_ids, 0)):
    for ax in row_axes:
        ax.axis('off')  # also hides panels left empty when a class has too few patients
    for ax, pid in zip(row_axes, ids):
        ax.imshow(load_central_slice(pid), cmap='gray')
        ax.set_title(f'Patient {pid}\nLabel={label}')
fig.tight_layout()

## 5. Pixel-Intensity Histogram

In [None]:
# Gather intensities from the small sample shown above to keep memory usage low.
# Keep one flat array per slice and concatenate once: extending a Python list
# pixel-by-pixel would box every value into its own Python object.
samples = pos_ids.tolist() + neg_ids.tolist()
slice_pixels = [load_central_slice(pid).ravel() for pid in tqdm(samples, desc='loading slices')]
pixels = np.concatenate(slice_pixels) if slice_pixels else np.empty(0, dtype=np.float32)

# Values are per-slice min-max normalized (see load_central_slice), not raw scanner units.
fig, ax = plt.subplots(figsize=(8, 4))
ax.hist(pixels, bins=50, alpha=0.7)
ax.set(title='Intensity Distribution (Sample)', xlabel='Normalized intensity', ylabel='Frequency')
plt.show()

## 6. Quick Training Sanity Check 

In [None]:
import torch
from torch.utils.data import DataLoader
from src.dataloader import LIDCDataset
from src.model import NoduleClassifier

# Small one-batch test: a single forward pass to check that the dataset,
# loader and model fit together.
# NOTE(review): `src` must be importable from the kernel's working directory; the
# '../data' paths suggest the notebook runs from a sub-folder — confirm sys.path.
torch.manual_seed(0)  # shuffle=True draws from torch's global RNG; seed for re-runs

ds = LIDCDataset('../data/LIDC-IDRI', '../data/annotations.csv', transforms=None)
dl = DataLoader(ds, batch_size=4, shuffle=True)
imgs, labels = next(iter(dl))
model = NoduleClassifier(pretrained=False)
with torch.no_grad():  # forward-only check; no autograd graph needed
    out = model(imgs)
print('Output shape:', out.shape)  # should be [batch_size, 2]
assert out.shape == (imgs.shape[0], 2), f'unexpected output shape {tuple(out.shape)}'