In this notebook, we try to

- train classificator on real 3D MRI brain images and corresponding gender labels 
- after that, we use previously synthesized images to predict gender labels for them
- finally, we train a simple logistic regression on latent representations of synthesized images to fit labels, predicted on previous step

Then, for a real image from initial dataset, we can get its modification which would be classified more as man or a woman. With that, patterns connected to this data fluctuation would be visible.

In [1]:
import numpy as np
import torch
import matplotlib.pyplot as plt

In [2]:
import torch.utils.data as torch_data
from torchvision.utils import save_image
import torch.optim as optim
from torch import nn
import torch.nn.functional as F
from torch.autograd import Variable
import random
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from skimage.transform import resize
from tqdm import tqdm

In [3]:
#from google.colab import drive
#drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [4]:
class MRIData(torch_data.Dataset):
    def __init__(self, X, y):
        super(MRIData, self).__init__()
        self.X = X
        self.y = y
    
        data_min = self.X.min(axis=(2, 3, 4))
        data_max = self.X.max(axis=(2, 3, 4))
        data_max[data_max == 0.] = 1.
        dif = data_max - data_min

        self.X = (self.X - data_min[..., None, None, None]) / dif[..., None, None, None]

    def __len__(self):

        return len(self.X)
    
    def __getitem__(self, idx):

        return self.X[idx], self.y[idx]

In [5]:
mri_data = np.load('tensors.npy')
labels = np.load('sex.npy')

In [6]:
torch.manual_seed(0)

data_train, data_val, y_train, y_val = train_test_split(mri_data, labels,
                                                        test_size=0.2)

data_train = np.pad(data_train, ((0, 0), (6, 6), (0, 0), (6, 6)))
data_train = data_train.reshape(-1, 1, 70, 70, 70)

data_val = np.pad(data_val, ((0, 0), (6, 6), (0, 0), (6, 6)))
data_val = data_val.reshape(-1, 1, 70, 70, 70)

In [7]:
batch_size = 8

train_dset = MRIData(data_train, y_train)
test_dset = MRIData(data_val, y_val)

train_loader = torch_data.DataLoader(train_dset, batch_size=10, shuffle=True)
test_loader = torch_data.DataLoader(test_dset, batch_size=8) 

In [8]:
class ResBlock(nn.Module):

    def __init__(self, in_channels, out_channels, use_bn=True, padding=1, kernel_size=3):
        super(ResBlock, self).__init__()

        if use_bn:
          self.bn1 = nn.BatchNorm3d(in_channels)
          self.bn2 = nn.BatchNorm3d(in_channels // 2)
        else:
          self.bn1 = Identity()
          self.bn2 = Identity()
        
        self.conv0 = nn.Conv3d(in_channels, out_channels, kernel_size=1)
        self.conv1 = nn.Conv3d(in_channels, in_channels // 2, kernel_size=3, padding=padding)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv3d(in_channels // 2, out_channels, kernel_size=3, padding=padding)


    def forward(self, x):
        identity = self.conv0(x)

        out = self.bn1(x)
        out = self.relu(x)
        out = self.conv1(x)

        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv2(out)

        out += identity

        return out

In [9]:
class Flatten(torch.nn.Module):

    def forward(self, x):
        batch_size = x.shape[0]
        return x.view(batch_size, -1)
    

class FC_Classifier(nn.Module):
    
    def __init__(self, in_channels, n_classes):
        super(FC_Classifier, self).__init__()
        
        n = 8
        
        self.conv1 = nn.Conv3d(in_channels, n * 2, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm3d(n * 2)
        self.conv2 = ResBlock(n * 2, n * 4, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm3d(n * 4)
        self.conv3 = ResBlock(n * 4, n * 8, kernel_size=3, padding=1)
        self.act = nn.ReLU()
        self.pool = nn.MaxPool3d(2)
        self.dropout = nn.Dropout3d()
        self.flatten = Flatten()
        self.linear1 = nn.Linear((8 ** 3) * n * 8, 200)
        self.bn3 = nn.BatchNorm1d(200)
        self.linear2 = nn.Linear(200, n_classes)
        
    def forward(self, x):
        x = self.conv1(x)
        # x = self.bn1(x)
        # x = self.act(x)
        x = self.pool(x)

        x = self.conv2(x)
        # x = self.bn2(x)
        # x = self.act(x)
        x = self.pool(x)

        x = self.conv3(x)
        # x = self.act(x)
        x = self.pool(x)
        x = self.flatten(x)
        x = self.dropout(x)
        x = self.linear1(x)
        x = self.bn3(x)
        x = self.act(x)
        x = self.linear2(x)
        
        return x

In [10]:
classifier = FC_Classifier(in_channels=1, n_classes=1)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
classifier.to(device)

criterion = nn.BCEWithLogitsLoss()  
optimizer = torch.optim.Adam(classifier.parameters(), lr=5e-3) 

In [11]:
def train_classifier(classifier, criterion, optimizer, train_loader, test_loader, n_epochs=50):
    
    for epoch in range(1, n_epochs+1):
        classifier.train()

        losses_train = []
        
        for X, target in train_loader:
            if len(X) > 1:
              X, target = X.to(device), target.to(device)
              target = target.float().reshape(-1, 1)
              
              
              optimizer.zero_grad()
              
              logits = classifier(X)
              loss = criterion(logits, target)
              losses_train.append(loss.detach().cpu().numpy())

              loss.backward()
              optimizer.step()
        
        if epoch % 5 == 0:
            
            y_pred_val =  []
            y_true_val = []

            classifier.eval()
            
            losses_val = []
            
            for X, target in test_loader:
                X, target = X.to(device), target.to(device)
                target = target.float().reshape(-1, 1)

                logits = classifier(X)
                val_loss = criterion(logits, target)
                losses_val.append(val_loss.detach().cpu().numpy())
                
                target_hat_val = torch.sigmoid(logits)
                target_hat_val = target_hat_val > 0.5
                target_hat_val = target_hat_val.float()

                y_pred_val.extend(target_hat_val.tolist())
                y_true_val.extend(target.tolist())

            mean_val = sum(losses_val) / len(losses_val)
            mean_train = sum(losses_train) / len(losses_train)

            print('Val epoch {}'.format(epoch), \
              ', Loss : {:.3}'.format(mean_train), \
              ', Accuracy on test: {:.3}'.format(accuracy_score(y_true_val, y_pred_val)) )
            
            torch.save(classifier.state_dict(), 'gender_classifier_checkpoint.pth')


In [12]:
 train_classifier(classifier, criterion, optimizer, train_loader, test_loader, n_epochs=50)

Val epoch 5 , Loss : 0.355 , Accuracy on test: 0.879
Val epoch 10 , Loss : 0.319 , Accuracy on test: 0.937
Val epoch 15 , Loss : 0.246 , Accuracy on test: 0.942
Val epoch 20 , Loss : 0.194 , Accuracy on test: 0.919
Val epoch 25 , Loss : 0.174 , Accuracy on test: 0.969
Val epoch 30 , Loss : 0.154 , Accuracy on test: 0.969
Val epoch 35 , Loss : 0.0972 , Accuracy on test: 0.964
Val epoch 40 , Loss : 0.0972 , Accuracy on test: 0.955
Val epoch 45 , Loss : 0.134 , Accuracy on test: 0.928
Val epoch 50 , Loss : 0.0766 , Accuracy on test: 0.946


Load all previously trained models

In [13]:
from models import Generator3D_Adaptive, Encoder
from utils import plot_central_cuts, Fake_MRIData

In [14]:
generator = Generator3D_Adaptive()
encoder = Encoder()
classifier = FC_Classifier(in_channels=1, n_classes=1)

generator.load_state_dict(torch.load('generator_checkpoint.pth'))
encoder.load_state_dict(torch.load('encoder_checkpoint.pth'))
classifier.load_state_dict(torch.load('gender_classifier_checkpoint.pth'))


generator.eval()
encoder.eval()
classifier.eval();
#generator.to(device)
#encoder.to(device)
#classifier.to(device);

Load images, synthesized by generator, and its corresponding latent representations

In [15]:
fake_data = np.load('fake_data.npy')
latent_repr = np.load('latent_repr.npy')

In [16]:
fake_dset = Fake_MRIData(fake_data, latent_repr)

fake_data_loader = torch_data.DataLoader(fake_dset, batch_size=10, shuffle=False)

Predict genders for synthesized images

In [17]:
fake_gender_labels = np.array([])

for imgs, _ in fake_data_loader:
  logits = classifier(imgs)

  gender_labels = torch.sigmoid(logits)
  gender_labels = gender_labels > 0.5
  gender_labels = gender_labels.float()

  fake_gender_labels = np.concatenate((fake_gender_labels, gender_labels.detach().numpy().reshape(-1)))


In [36]:
np.save('fake_gender_labels.npy', fake_gender_labels)

In [18]:
fake_gender_labels.max()

0.0

Unfortunately, it looks like fake images, synthesized by generator, are too poor quality for the classifier to retrieve necessary features from them. Though many efforts were applied to train a stable generator, there should be a lot more precise tuning of both models to work together well, which is impossible for our time frames    