In [20]:
import os
import torch
import pandas as pd
import numpy as np
import torchvision
import torch.nn as nn
from torch.utils.data import Dataset, random_split, DataLoader
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torchvision.utils import make_grid
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms
import torchvision.models as models
from PIL import Image
from collections import OrderedDict
import zipfile
%matplotlib inline
import os
import numpy as np
import nibabel as nib
import torch
from torch.utils.data import Dataset, DataLoader

In [2]:
# Download url of normal CT scans.
url_normal = "https://github.com/hasibzunair/3D-image-classification-tutorial/releases/download/v0.2/CT-0.zip"
filename_normal = os.path.join(os.getcwd(), "CT-0.zip")
torchvision.datasets.utils.download_and_extract_archive(url_normal, filename_normal)

# Download url of abnormal CT scans.
url_abnormal = "https://github.com/hasibzunair/3D-image-classification-tutorial/releases/download/v0.2/CT-23.zip"
filename_abnormal = os.path.join(os.getcwd(), "CT-23.zip")
torchvision.datasets.utils.download_and_extract_archive(url_abnormal, filename_abnormal)


Downloading https://objects.githubusercontent.com/github-production-release-asset-2e65be/296429475/b717cc00-fe6a-11ea-8c3a-a7c0583602e5?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240208%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240208T155515Z&X-Amz-Expires=300&X-Amz-Signature=71d4d04e3b549ad582f88577d3e9f02f3fa65f2c167f771d9d343a7b341b5b39&X-Amz-SignedHeaders=host&actor_id=0&key_id=0&repo_id=296429475&response-content-disposition=attachment%3B%20filename%3DCT-0.zip&response-content-type=application%2Foctet-stream to /content/CT-0.zip/CT-0.zip


100%|██████████| 1065471431/1065471431 [00:08<00:00, 132594599.03it/s]


Extracting /content/CT-0.zip/CT-0.zip to /content/CT-0.zip
Downloading https://objects.githubusercontent.com/github-production-release-asset-2e65be/296429475/4deebd80-00e3-11eb-961e-4dae6b94b040?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240208%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240208T155528Z&X-Amz-Expires=300&X-Amz-Signature=96b8eb3c1b17b509e0ec028655fe5f4ee1118f9d09ee847da250b45d4785e6b1&X-Amz-SignedHeaders=host&actor_id=0&key_id=0&repo_id=296429475&response-content-disposition=attachment%3B%20filename%3DCT-23.zip&response-content-type=application%2Foctet-stream to /content/CT-23.zip/CT-23.zip


100%|██████████| 1045162547/1045162547 [00:07<00:00, 140285304.01it/s]


Extracting /content/CT-23.zip/CT-23.zip to /content/CT-23.zip


In [52]:
class NiftiDataset(Dataset):
    def __init__(self, root_dir, file_pattern, num_samples=None):
        self.root_dir = root_dir
        self.file_pattern = file_pattern
        self.num_samples = num_samples

    def __len__(self):
        return len(os.listdir(self.root_dir))

    def __getitem__(self, idx):
        filepath = os.path.join(self.root_dir, self.file_pattern % idx)
        # Load the NIfTI file
        nii = nib.load(filepath).obj

        # Extract the image data
        image = nii.get_fdata().str() % idx

        # Convert the image data to a tensor
        image_tensor = torch.from_numpy(image).float()

        # Normalize the pixel values to be between -1 and 1
        image_tensor = image_tensor / image_tensor.max()

        # Return the image tensor and its corresponding label
        return image_tensor, np.random.randint(0, 2, size=1)  # Replace with actual label information


In [53]:
class NormalCellDataset(NiftiDataset):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.root_dir = '/content/CT-0.zip/CT-0'
        self.file_pattern = 'image_{}.nii.gz'

class AbnormalCellDataset(NiftiDataset):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.root_dir = '/content/CT-23.zip/CT-23'
        self.file_pattern = 'image_{}.nii.gz'


In [54]:
normal_dataset = NormalCellDataset(root_dir='/content/CT-0.zip/CT-0', file_pattern='image_{}.nii.gz', num_samples=1000)
abnormal_dataset = AbnormalCellDataset(root_dir='/content/CT-23.zip/CT-23', file_pattern='image_{}.nii.gz', num_samples=1000)

In [55]:
# Create data loaders for the datasets
normal_loader = DataLoader(normal_dataset, batch_size=32, shuffle=True)
abnormal_loader = DataLoader(abnormal_dataset, batch_size=32, shuffle=True)


In [56]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class ThreeCnn(nn.Module):
    def __init__(self, width=128, height=128, depth=64):
        super(ThreeCnn, self).__init__()
        self.cellfirst = nn.Sequential(
            nn.Conv3d(in_channels=1, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool3d(kernel_size=2),
            nn.BatchNorm3d(num_features=64),
            nn.ReLU())

        self.cellsecond = nn.Sequential(
            nn.Conv3d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool3d(kernel_size=2),
            nn.BatchNorm3d(num_features=64),
            nn.ReLU())

        self.cellthird = nn.Sequential(
            nn.Conv3d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool3d(kernel_size=2),
            nn.BatchNorm3d(num_features=128),
            nn.ReLU())

        self.cellforth = nn.Sequential(
            nn.Conv3d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool3d(kernel_size=2),
            nn.BatchNorm3d(num_features=256),
            nn.ReLU())

        self.final = nn.Sequential(
            nn.AvgPool3d((1, 1, 1)),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 1),
            nn.Sigmoid())

    def forward(self, x):
      x = self.cellfirst(x)
      x = self.cellsecond(x)
      x = self.cellthird(x)
      x = self.cellforth(x)
      x = self.final(x)
      x = x.view(x.size(0), -1)
      return x



In [57]:
model = ThreeCnn()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

ThreeCnn(
  (cellfirst): Sequential(
    (0): Conv3d(1, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (1): ReLU()
    (2): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): ReLU()
  )
  (cellsecond): Sequential(
    (0): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (1): ReLU()
    (2): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): ReLU()
  )
  (cellthird): Sequential(
    (0): Conv3d(64, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (1): ReLU()
    (2): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): ReLU()
  )
  (cellforth): Sequential(
 

In [58]:
import torch.optim as optim
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
num_epochs = 10

In [59]:
for epoch in range(num_epochs):
    running_loss = 0.0
    for i, data in enumerate(normal_loader):
        inputs, labels = data
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print('Epoch {}: Loss = {:.4f}'.format(epoch+1, running_loss/(i+1)))

    running_loss = 0.0
    for i, data in enumerate(abnormal_loader):
        inputs, labels = data
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print('Epoch {}: Loss = {:.4f}'.format(epoch+1, running_loss/(i+1)))

AttributeError: 'NormalCellDataset' object has no attribute 'get_fdata'