In [6]:
import torch
import torchvision
import psycopg2
import torchvision.transforms as transforms
import torch.nn as nn
from tqdm.auto import tqdm
import io
from PIL import Image
from sklearn.metrics import roc_curve, auc
from torch.utils.data import Dataset
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
import pickle


In [7]:
torch.cuda.empty_cache()
batch_size=32
epochs=15
#device = "cuda" if torch.cuda.is_available() else "cpu"
device = "cpu"

In [8]:
transformation = transforms.Compose(
    [    
        transforms.Resize((256, 256)),
        transforms.RandomRotation(20),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)),
    ]
)

base_transform = torchvision.transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)),
    ]
)

In [9]:

class CustomDataset(Dataset):
    def __init__(self, data, labels, transform=None):
        self.data = data
        self.labels = labels
        self.transform = transform
    def __getitem__(self, index):
        img = self.data[index]
        if self.transform:
            img = self.transform(img)
        label = self.labels[index]
        return img, label

    def __len__(self):
        return len(self.data)

In [10]:

conn_select = psycopg2.connect(
    database="hse_medical",
    user='hse_medical',
    password='123456',
    host='127.0.0.1',
    port='5450',
    options="-c search_path=analyze_medical"
)

conn_select.autocommit = True

transform = transforms.ToTensor()

def get_connection():
    return conn_select

In [11]:


cursor = conn_select.cursor()

sql1 = f'''select 
    target,
    image from medical_pictures_train;'''
cursor.execute(sql1)
data_postgres = cursor.fetchall()
print(f"datatrain_size : {len(data_postgres)}")
cursor.close()
targets = []
images = []
for data in data_postgres:
    targets.append(data[0])
    image = Image.open(io.BytesIO(data[1]))
    images.append(image)
    

data_train = CustomDataset(images, targets,transform=transformation)

datatrain_size : 15588


In [12]:
cursor_test= conn_select.cursor()

sql2 = f'''select 
    target,
    image from medical_pictures_test;'''
cursor_test.execute(sql2)
data_postgres_test = cursor_test.fetchall()
cursor_test.close()
targets_test = []
images_test = []


for data in data_postgres_test:
    targets_test.append(data[0])
    bytes_io = io.BytesIO(data[1])
    image_open = Image.open(bytes_io)
    images_test.append(image_open)
    
    

data_test = CustomDataset(images_test, targets_test,transform=base_transform)

In [13]:
train_dataloader = torch.utils.data.DataLoader(
    data_train, batch_size=batch_size, shuffle=True, num_workers=4
)
val_dataloader = torch.utils.data.DataLoader(
    data_test, batch_size=batch_size, shuffle=False, num_workers=4
)

In [14]:

vgg19 = torchvision.models.vgg19(pretrained=True)

vgg19.classifier[6] = nn.Linear(4096, 23)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(vgg19.parameters(), lr=0.001, momentum=0.9)




In [15]:


for epoch in range(epochs):  # loop over the dataset multiple times
    running_loss = 0.0
    for i, data in enumerate(tqdm(train_dataloader), 0):
        inputs, labels = data
        optimizer.zero_grad()
        outputs = vgg19(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        if i % 100 == 99:  
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 100))
            running_loss = 0.0

print('Finished Training')
with open('vgg19.pickle', 'wb') as f:
    pickle.dump(vgg19, f)

  0%|          | 0/488 [00:00<?, ?it/s]

[1,   100] loss: 2.795
[1,   200] loss: 2.558
[1,   300] loss: 2.504
[1,   400] loss: 2.395


  0%|          | 0/488 [00:00<?, ?it/s]

[2,   100] loss: 2.230
[2,   200] loss: 2.206
[2,   300] loss: 2.189
[2,   400] loss: 2.162


  0%|          | 0/488 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f26d89d7b50>
Traceback (most recent call last):
  File "/home/roman/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1478, in __del__
    self._shutdown_workers()
  File "/home/roman/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1461, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f26d89d7b50>
Traceback (most recent call last):
  File "/home/roman/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1478, in __del__
    self._shutdown_workers()
  File "/home/roman/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1461, in _shutdown_workers
    if

[3,   100] loss: 2.057
[3,   200] loss: 2.005
[3,   300] loss: 2.004
[3,   400] loss: 1.971


  0%|          | 0/488 [00:00<?, ?it/s]

[4,   100] loss: 1.852
[4,   200] loss: 1.829
[4,   300] loss: 1.828
[4,   400] loss: 1.834


  0%|          | 0/488 [00:00<?, ?it/s]

[5,   100] loss: 1.720
[5,   200] loss: 1.655
[5,   300] loss: 1.676
[5,   400] loss: 1.694


  0%|          | 0/488 [00:00<?, ?it/s]

[6,   100] loss: 1.530
[6,   200] loss: 1.555
[6,   300] loss: 1.574
[6,   400] loss: 1.529


  0%|          | 0/488 [00:00<?, ?it/s]

[7,   100] loss: 1.411
[7,   200] loss: 1.390
[7,   300] loss: 1.451
[7,   400] loss: 1.480


  0%|          | 0/488 [00:00<?, ?it/s]

[8,   100] loss: 1.306
[8,   200] loss: 1.266
[8,   300] loss: 1.289
[8,   400] loss: 1.326


  0%|          | 0/488 [00:00<?, ?it/s]

[9,   100] loss: 1.206
[9,   200] loss: 1.201
[9,   300] loss: 1.216
[9,   400] loss: 1.211


  0%|          | 0/488 [00:00<?, ?it/s]

[10,   100] loss: 1.089
[10,   200] loss: 1.085
[10,   300] loss: 1.094
[10,   400] loss: 1.056


  0%|          | 0/488 [00:00<?, ?it/s]

[11,   100] loss: 0.935
[11,   200] loss: 0.974
[11,   300] loss: 0.980
[11,   400] loss: 0.985


  0%|          | 0/488 [00:00<?, ?it/s]

[12,   100] loss: 0.901
[12,   200] loss: 0.845
[12,   300] loss: 0.894
[12,   400] loss: 0.933


  0%|          | 0/488 [00:00<?, ?it/s]

[13,   100] loss: 0.786
[13,   200] loss: 0.804
[13,   300] loss: 0.821
[13,   400] loss: 0.817


  0%|          | 0/488 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f26d89d7b50>
Traceback (most recent call last):
  File "/home/roman/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1478, in __del__
    self._shutdown_workers()
  File "/home/roman/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1461, in _shutdown_workers
    Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f26d89d7b50>
Traceback (most recent call last):
Exception ignored in: if w.is_alive():
<function _MultiProcessingDataLoaderIter.__del__ at 0x7f26d89d7b50>  File "/home/roman/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1478, in __del__

Traceback (most recent call last):
  File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
      File "/home/roman/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1478, in __del__
        self._shutdown_workers()self._

[14,   100] loss: 0.692
[14,   200] loss: 0.751
[14,   300] loss: 0.735
[14,   400] loss: 0.740


  0%|          | 0/488 [00:00<?, ?it/s]

Exception ignored in: Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f26d89d7b50><function _MultiProcessingDataLoaderIter.__del__ at 0x7f26d89d7b50>
<function _MultiProcessingDataLoaderIter.__del__ at 0x7f26d89d7b50>
Traceback (most recent call last):

Traceback (most recent call last):
  File "/home/roman/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1478, in __del__
Traceback (most recent call last):
  File "/home/roman/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1478, in __del__
  File "/home/roman/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1478, in __del__
        self._shutdown_workers()    
self._shutdown_workers()self._shutdown_workers()  File "/home/roman/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1461, in _shutdown_workers

Exception ignored in:     
  File "/home/roman/.local/lib/python3.10/site-packages/to

[15,   100] loss: 0.620
[15,   200] loss: 0.636
[15,   300] loss: 0.694
[15,   400] loss: 0.632
Finished Training


In [34]:
from sklearn.metrics import roc_auc_score

correct = 0
total = 0
all_labels = []
all_predicted_probs = []
y_true = []
y_pred = []
predicted_values = []
with torch.no_grad():
    for data in tqdm(val_dataloader):
        images, labels = data
        outputs = vgg19(images)
        _, predicted = torch.max(outputs, 1)
        predicted_values.extend(predicted.cpu().numpy())
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        probabilities = torch.nn.functional.softmax(outputs, dim=1)
        y_true.extend(labels.cpu().numpy())
        y_pred.extend(probabilities.cpu().numpy())
print('Accuracy of the network on the test images: %d %%' % (100 * correct / total))





  0%|          | 0/126 [00:00<?, ?it/s]

Accuracy of the network on the test images: 50 %


In [35]:
from sklearn.metrics import classification_report, precision_score

y_true = torch.tensor(y_true)
y_pred = torch.tensor(y_pred)

roc_auc = roc_auc_score(y_true, y_pred, multi_class='ovr')

target_names = ['1', '2', '3', '4','5',
                '6','7','8','9','10','11','12',
                '13','14','15','16','17',
                '18','19','20','21','22','23']
print("ROC-AUC: {:.2f}%".format(roc_auc * 100))
precision = precision_score(y_true,predicted_values,average='macro')

print("precision : {:.2f}%".format(precision * 100))



ROC-AUC: 89.80%
precision : 48.72%
