In [97]:
import torch 

torch.cuda.is_available()

True

In [98]:
from torch import nn

class MedicalCNN(nn.Module):
    def __init__(self, num_classes=4):
        super(MedicalCNN, self).__init__()
        
        self.conv_1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.conv_2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.conv_3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1)
        
        self.fully_connected_1 = nn.Linear(128 * 32 * 32, 512)
        self.fully_connected_2 = nn.Linear(512, num_classes)
        
        self.pool = nn.MaxPool2d(2, 2)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=.5)
        
        
    def forward(self, x):
        x = self.pool(self.relu(self.conv_1(x)))
        x = self.pool(self.relu(self.conv_2(x)))
        x = self.pool(self.relu(self.conv_3(x)))
        
        x = x.view(-1, 128 * 32 * 32)
        x = self.relu(self.fully_connected_1(x))
        x = self.dropout(x)
        x = self.fully_connected_2(x)
        
        return x

In [99]:
import pandas as pd

df = pd.read_csv('data/missing_data.csv')

df

Unnamed: 0.1,dicom_id,subject_id,study_id,PerformedProcedureStepDescription,ViewPosition,Rows,Columns,StudyDate,StudyTime,ProcedureCodeSequence_CodeMeaning,...,Lung Opacity,No Finding,Pleural Effusion,Pleural Other,Pneumonia,Pneumothorax,Support Devices,Airspace Opacity,Unnamed: 0,Image_Path
0,00066c6b-67e23e14-d4dbe574-c1740091-bd4e50c6,11982346,54243900,CHEST (PA AND LAT),,2140,1760,21650425,135606.000,CHEST (PA AND LAT),...,,1.0,,,,,,,236635,/media/mohammad/Vir2_Pre/NLP Project/NLP_data_...
1,0006f794-93547e63-3a14d1d3-486c5c6c-6bbac987,13679217,58086261,CHEST (PA AND LAT),,2140,1760,21580429,132922.000,CHEST (PA AND LAT),...,1.0,,1.0,,1.0,,,,121753,/media/mohammad/Vir2_Pre/NLP Project/NLP_data_...
2,0006ffca-fee7bc9c-bb4e3942-4e61b867-7e77af78,10137100,57298029,,,2140,1760,21280703,150737.000,CHEST (PA AND LAT),...,,,,,1.0,,,,102037,/media/mohammad/Vir2_Pre/NLP Project/NLP_data_...
3,000afff7-02b1aca6-1646f6cb-9da6450c-4e23b7f1,12460613,59177929,CHEST (PA AND LAT),,2140,1760,21671122,153844.000,CHEST (PA AND LAT),...,,1.0,0.0,,0.0,0.0,,,133174,/media/mohammad/Vir2_Pre/NLP Project/NLP_data_...
4,0010318e-7d5baf9d-075dcc7f-b18f9fcc-9bb36faa,19777911,57672230,CHEST (PORTABLE AP),,2140,1760,21650817,211552.000,CHEST (PORTABLE AP),...,1.0,,,,-1.0,,1.0,,116290,/media/mohammad/Vir2_Pre/NLP Project/NLP_data_...
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
15764,ffea2e64-f7c0ec82-b1aabcbb-73303eed-9026d7c6,13051530,59235485,CHEST (PA AND LAT),,2140,1760,21760117,105126.000,CHEST (PA AND LAT),...,,,1.0,,,0.0,0.0,,135322,/media/mohammad/Vir2_Pre/NLP Project/NLP_data_...
15765,fff524dd-c2926e8b-19704510-295b8889-63132157,17289501,59670483,CHEST (PA AND LAT),,2140,1760,21370509,91043.000,CHEST (PA AND LAT),...,,,0.0,,,0.0,,,161914,/media/mohammad/Vir2_Pre/NLP Project/NLP_data_...
15766,fff6e8d3-b6118442-d3b803ea-0d4bfc82-3669c4e8,16624064,56282440,CHEST (PA AND LAT),,2140,1760,21550313,93406.758,CHEST (PA AND LAT),...,,1.0,,,0.0,,,,63638,/media/mohammad/Vir2_Pre/NLP Project/NLP_data_...
15767,fffaca7b-b858376c-8d540985-db0db276-c1bbcb1b,16970933,51602986,CHEST (PA AND LAT),,2140,1760,21960409,93558.000,CHEST (PA AND LAT),...,,1.0,,,,,,,297696,/media/mohammad/Vir2_Pre/NLP Project/NLP_data_...


In [100]:
from torch.utils.data import Dataset
from PIL import Image

class MedicalDataset(Dataset):
    def __init__(self, csv_dataset, transform=None):
        super(MedicalDataset, self).__init__()
        self.data = csv_dataset
        self.transform = transform
        

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        
        image_path = row['Image_Path']
        image = Image.open(image_path).convert('L')
        
        if self.transform:
            image = self.transform(image)
            
        image = image.to('cuda')
        
        
        return image

In [101]:
from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[.5], std=[.5])
])

dataset = MedicalDataset(df[7000:8000], transform)

In [102]:
from torch.utils.data import DataLoader

dataset = DataLoader(dataset, batch_size=32, shuffle=False)

In [103]:


model = MedicalCNN(num_classes=4).to('cuda')
model_state = torch.load('/media/mohammad/Vir2_Pre/NLP Project/nlp code/CNN/model_1.pth', weights_only=False)
model.load_state_dict(model_state)
model.eval()



MedicalCNN(
  (conv_1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv_2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv_3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (fully_connected_1): Linear(in_features=131072, out_features=512, bias=True)
  (fully_connected_2): Linear(in_features=512, out_features=4, bias=True)
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (relu): ReLU()
  (dropout): Dropout(p=0.5, inplace=False)
)

In [104]:
df2 = pd.DataFrame(columns=["prediction","confidence"])

# outputs["prediction"] = None
# outputs["confidence"] = None

In [105]:
from tqdm import tqdm

# predictions = []
# accuracies = []

for inputs in tqdm(dataset):
    inputs = inputs.to('cuda')

    with torch.no_grad():
        outputs = model(inputs)
        probabilities = torch.nn.functional.softmax(outputs, dim=1)
        confidence, predicted_classes = torch.max(probabilities, dim=1)
    for p,a in zip(predicted_classes.cpu().numpy(),confidence.cpu().numpy() * 100):
        df2.loc[len(df2)] = [p,a]

    # predictions.extend(predicted_classes.cpu().numpy())
    # accuracies.extend(confidence.cpu().numpy() * 100)

# df["prediction"] = predictions
# df["confidence"] = accuracies

 78%|███████▊  | 25/32 [01:10<00:20,  2.87s/it]

In [None]:

df2.to_csv("updated_dataset_7000_8000.csv", index=False)