## **挂载 Drive + 切到项目目录**

In [3]:
from google.colab import drive
drive.mount('/content/drive')

import os, sys

BASE_DIR = "/content/drive/MyDrive/BIA4/classification"
os.chdir(BASE_DIR)
sys.path.append(BASE_DIR)

print("CWD:", os.getcwd())
print("Sub-dirs:", os.listdir())


Mounted at /content/drive
CWD: /content/drive/MyDrive/BIA4/classification
Sub-dirs: ['README.md', '.gitignore', 'utils.py', 'LICENSE', 'test_mpidb_dataloader.py', 'test_model_forward.py', 'main.py', 'data', 'model', 'lightning_logs', 'Train_Validate_Script.ipynb']


## **安装依赖**

In [1]:
!pip3 uninstall --yes torch torchaudio torchvision torchtext torchdata
!pip3 install torch torchaudio torchvision torchtext torchdata

!pip install -q \
    pytorch-lightning==2.4.0 \
    albumentations==1.4.18 \
    opencv-python \
    numpy \
    torchmetrics \
    tqdm \
    pandas \
    scikit-learn \
    tensorboard \
    optuna \
    torchviz


Found existing installation: torch 2.8.0+cu126
Uninstalling torch-2.8.0+cu126:
  Successfully uninstalled torch-2.8.0+cu126
Found existing installation: torchaudio 2.8.0+cu126
Uninstalling torchaudio-2.8.0+cu126:
  Successfully uninstalled torchaudio-2.8.0+cu126
Found existing installation: torchvision 0.23.0+cu126
Uninstalling torchvision-0.23.0+cu126:
  Successfully uninstalled torchvision-0.23.0+cu126
[0mFound existing installation: torchdata 0.11.0
Uninstalling torchdata-0.11.0:
  Successfully uninstalled torchdata-0.11.0
Collecting torch
  Downloading torch-2.9.0-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (30 kB)
Collecting torchaudio
  Downloading torchaudio-2.9.0-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (6.9 kB)
Collecting torchvision
  Downloading torchvision-0.24.0-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (5.9 kB)
Collecting torchtext
  Downloading torchtext-0.18.0-cp312-cp312-manylinux1_x86_64.whl.metadata (7.9 kB)
Collecting torchdata
  Downloading torchdata

In [2]:
# Testing the Package Loading
import torch, pytorch_lightning as pl, albumentations
import cv2, numpy as np
from torchmetrics.classification import MulticlassAccuracy

print("torch:", torch.__version__)
print("cuda available:", torch.cuda.is_available())
print("lightning:", pl.__version__)
print("albumentations:", albumentations.__version__)


torch: 2.9.0+cu128
cuda available: True
lightning: 2.4.0
albumentations: 1.4.18


  check_for_updates()


## 调用数据库和模型

In [11]:
import os, sys
print("Before CWD:", os.getcwd())
!ls
%cd /content/drive/MyDrive/BIA4/classification

print("Now CWD:", os.getcwd())
print("Files here:")
!ls

project_root = os.getcwd()
if project_root not in sys.path:
    sys.path.append(project_root)

# # Import Database Class and Test
# from data import DInterface
# data_root = os.path.join(project_root, "data", "MPIDB")
# print("Data root:", data_root)

# dm = DInterface(
#     num_workers=4,
#     dataset="mpidb_dataset",
#     batch_size=64,
#     pin_memory=True,
#     root=data_root,
#     classes=["falciparum", "vivax", "ovale"],
#     img_size=100,
#     aug=True,
# )

# dm.setup("fit")
# train_loader = dm.train_dataloader()
# val_loader = dm.val_dataloader()

# xb, yb = next(iter(train_loader))
# print("Train batch:", xb.shape, yb.shape, "labels:", yb[:8])

# xbv, ybv = next(iter(val_loader))
# print("Val batch:  ", xbv.shape, ybv.shape, "labels:", ybv[:8])

# # Import Model Class and Test
# from model.standard_net import StandardNetLightning
# from data import DInterface

# dm = DInterface(
#     num_workers=4,
#     dataset='mpidb_dataset',
#     batch_size=8,
#     root='data/MPIDB',
#     classes=['falciparum','vivax','ovale'],
#     img_size=100,
#     aug=True,
# )

# dm.setup('fit')
# model = StandardNetLightning(in_channels=7, num_classes=3)

# xb, yb = next(iter(dm.train_dataloader()))
# logits = model(xb)
# print("logits:", logits.shape)
# #[8, 3]


Before CWD: /content/drive/MyDrive/BIA4/classification
data		main.py    test_model_forward.py	utils.py
LICENSE		model	   test_mpidb_dataloader.py
lightning_logs	README.md  Train_Validate_Script.ipynb
/content
Now CWD: /content/drive/MyDrive/BIA4/classification
Files here:
data		main.py    test_model_forward.py	utils.py
LICENSE		model	   test_mpidb_dataloader.py
lightning_logs	README.md  Train_Validate_Script.ipynb


## Train

In [21]:
import pytorch_lightning as pl
pl.seed_everything(42, workers=True)
from data import DInterface

dm = DInterface(
    num_workers=4,
    dataset='mpidb_dataset',
    batch_size=32,
    root='data/MPIDB',
    classes=['falciparum', 'vivax', 'ovale'],
    img_size=100,
    aug=True,
)

from model.standard_net import StandardNetLightning

model = StandardNetLightning(
    in_channels=7,
    num_classes=3,
    lr=5e-4,
    weight_decay=1e-4,
    dropout=0.3,
)

total_params = sum(p.numel() for p in model.parameters())
print(f"Model ready. Params: {total_params/1e6:.3f} M")


import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger

logger = TensorBoardLogger(
    save_dir="lightning_logs",
    name="mpidb_7ch_cnn",
)

ckpt_callback = ModelCheckpoint(
    monitor="val_loss",
    mode="min",
    save_top_k=1,
    filename="mpidb-7ch-{epoch:02d}-{val_loss:.4f}-{val_acc:.4f}",
    save_weights_only=False,
)


early_stop = EarlyStopping(
    monitor="val_loss",
    mode="min",
    patience=8,
    verbose=True,
)

lr_monitor = LearningRateMonitor(logging_interval="epoch")

accelerator = "gpu" if torch.cuda.is_available() else "cpu"
devices = 1

print("Using accelerator:", accelerator)

trainer = pl.Trainer(
    max_epochs=30,
    accelerator=accelerator,
    devices=devices,
    logger=logger,
    callbacks=[ckpt_callback, early_stop, lr_monitor],
    log_every_n_steps=10,
    deterministic=True,
)


INFO:lightning_fabric.utilities.seed:Seed set to 42
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


Model ready. Params: 2.750 M
Using accelerator: gpu


In [None]:
trainer.fit(model, datamodule=dm)

/usr/local/lib/python3.12/dist-packages/pytorch_lightning/callbacks/model_checkpoint.py:654: Checkpoint directory lightning_logs/mpidb_7ch_cnn/version_3/checkpoints exists and is not empty.
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name      | Type               | Params | Mode 
---------------------------------------------------------
0 | model     | MPIDBCNN           | 2.8 M  | train
1 | criterion | CrossEntropyLoss   | 0      | train
2 | train_acc | MulticlassAccuracy | 0      | train
3 | val_acc   | MulticlassAccuracy | 0      | train
4 | test_acc  | MulticlassAccuracy | 0      | train
---------------------------------------------------------
2.8 M     Trainable params
0         Non-trainable params
2.8 M     Total params
11.001    Total estimated model params size (MB)
35        Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=30` reached.


In [19]:
from model.standard_net import StandardNetLightning
BASE_DIR = "/content/drive/MyDrive/BIA4/classification"
os.chdir(BASE_DIR)
ckpt_path = "/content/drive/MyDrive/BIA4/classification/lightning_logs/mpidb_7ch_cnn/version_3/checkpoints/mpidb-7ch-epoch=22-val_loss=0.6395-val_acc=0.6400.ckpt"
best_model = StandardNetLightning.load_from_checkpoint(ckpt_path)
best_model.eval()


MPIDBCNNLightning(
  (model): MPIDBCNN(
    (b1): Sequential(
      (0): Sequential(
        (0): Conv2d(7, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): LeakyReLU(negative_slope=0.1, inplace=True)
      )
      (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (2): Dropout(p=0.3, inplace=False)
    )
    (b2): Sequential(
      (0): Sequential(
        (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): LeakyReLU(negative_slope=0.1, inplace=True)
      )
      (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (2): Dropout(p=0.3, inplace=False)
    )
    (b3): Sequential(
      (0): Sequential(
        (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1)

In [22]:

test_results = trainer.test(best_model, datamodule=dm)
print(test_results)

INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: |          | 0/? [00:00<?, ?it/s]

[{'test_loss': 0.72015380859375, 'test_acc': 0.6428571343421936}]


In [23]:
import torch
from sklearn.metrics import classification_report, confusion_matrix
import numpy as np

best_model.eval()
best_model.to("cuda" if torch.cuda.is_available() else "cpu")

all_preds, all_labels = [], []
device = next(best_model.parameters()).device

for xb, yb in dm.test_dataloader():
    xb = xb.to(device)
    yb = yb.to(device)
    with torch.no_grad():
        logits = best_model(xb)
        preds = logits.argmax(1)
    all_preds.append(preds.cpu().numpy())
    all_labels.append(yb.cpu().numpy())

all_preds = np.concatenate(all_preds)
all_labels = np.concatenate(all_labels)

print("Confusion matrix:\n", confusion_matrix(all_labels, all_preds))
print("\nReport:\n", classification_report(all_labels, all_preds,
                                           target_names=['falciparum','vivax','ovale']))


Confusion matrix:
 [[12  5  0]
 [ 0  6  0]
 [ 3  2  0]]

Report:
               precision    recall  f1-score   support

  falciparum       0.80      0.71      0.75        17
       vivax       0.46      1.00      0.63         6
       ovale       0.00      0.00      0.00         5

    accuracy                           0.64        28
   macro avg       0.42      0.57      0.46        28
weighted avg       0.58      0.64      0.59        28



  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
