In [1]:
#These dataloaders are taken from FATE
from dataset.table import TableDataset
from dataset.image import ImageDataset

In [2]:
guest_train_path = '../examples/data/oasis2train/oasis2.csv'
host_train_path = '../examples/data/oasis2train/mri/'

guest_train = TableDataset()
guest_train.load(guest_train_path)
host_train = ImageDataset(return_label=False)
host_train.load(host_train_path)

guest_val_path = '../examples/data/oasis2val/oasis2.csv'
host_val_path = '../examples/data/oasis2val/mri/'

guest_val = TableDataset()
guest_val.load(guest_val_path)
host_val = ImageDataset(return_label=False)
host_val.load(host_val_path)

guest_test_path = '../examples/data/oasis2test/oasis2.csv'
host_test_path = '../examples/data/oasis2test/mri/'

guest_test = TableDataset()
guest_test.load(guest_test_path)
host_test = ImageDataset(return_label=False)
host_test.load(host_test_path)

In [3]:
!ls ../examples/data/oasis2train/

mri  oasis2.csv


In [4]:
import torch as t
from torch import nn
from torch.nn import Module
import logging
import numpy as np

class BottomHost(nn.Module):

    def __init__(self):
        super(BottomHost, self).__init__()
        self.cuda = True
        self.seq = t.nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3),
            nn.MaxPool2d(kernel_size=3),
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3),
            nn.AvgPool2d(kernel_size=3)
        )

        self.fc = t.nn.Sequential(   # extracted feature is a 8-dim embedding
            nn.Linear(32, 16),
            nn.ReLU(),
            nn.Linear(16, 10),
            nn.ReLU()
        )

    def forward(self, x):
        x = x
        x = self.seq(x)
        x = x.mean((2,3))
        #x = x.flatten(start_dim=1)
        x = self.fc(x)
        return x

In [5]:
import torch as t
from torch import nn
from torch.nn import Module

class BottomGuest(nn.Module):

    def __init__(self):
        super(BottomGuest, self).__init__()
        self.fc = t.nn.Linear(12,10)

    def forward(self, x):
        x = self.fc(x)
        return x

In [6]:
import torch as t
from torch import nn
from torch.nn import Module

class TopNet(nn.Module):

    def __init__(self):
        super(TopNet, self).__init__()
        self.fc = t.nn.Sequential(   
            nn.Linear(10, 3),
            t.nn.Softmax(dim=1)
        )

    def forward(self, x):
        x = self.fc(x)
        return x

In [7]:
import torch as t
from torch import nn
from torch.nn import Module

class InteractiveLayer(nn.Module):

    def __init__(self):
        super(InteractiveLayer, self).__init__()
        self.fc = t.nn.Linear(10,10)

    def forward(self, x):
        x = self.fc(x)
        return x

In [8]:
import torch as t
from torch import nn
from torch.nn import Module

class CompositeModel(nn.Module):

    def __init__(self):
        super(CompositeModel, self).__init__()
        self.cnn = BottomHost()
        self.fnn = BottomGuest()
        self.interactive = InteractiveLayer()
        self.top = TopNet()

    def forward(self, img, table):
        x = self.cnn(img) + self.fnn(table)
        x = self.interactive(x)
        x = self.top(x)
        return x
    
    def train(self, mode):
        self.cnn.train(mode=mode)
        self.fnn.train(mode=mode)
        self.interactive.train(mode=mode)
        self.top.train(mode=mode)

In [9]:
model = CompositeModel()

In [13]:
from torch.utils.data import Dataset, DataLoader

model = CompositeModel()

device = t.device("cuda" if t.cuda.is_available() else "cpu")
#device = 'cpu'
    
loss = t.nn.CrossEntropyLoss()
optimizer = t.optim.SGD(model.parameters(), lr=0.05, momentum=0.9)

model = model.to(device)

num_epochs = 50

batch_size = 32

train_dataloader = zip(DataLoader(host_train, batch_size=batch_size, shuffle=False), DataLoader(guest_train, batch_size=batch_size, shuffle=False))
val_dataloader = zip(DataLoader(host_val, batch_size=batch_size, shuffle=False), DataLoader(guest_val, batch_size=batch_size, shuffle=False))
test_dataloader = zip(DataLoader(host_test, batch_size=batch_size, shuffle=False), DataLoader(guest_test, batch_size=batch_size, shuffle=False))

n_train = len(list(train_dataloader))
n_val = len(list(val_dataloader))
n_test = len(list(test_dataloader))


#train_dataloader = train_dataloader.to(device)
#val_dataloader = val_dataloader.to(device)
#test_dataloader = test_dataloader.to(device)
# Iterate through the DataLoader

train_loss_history = []
val_loss_history = []

for epoch in range(num_epochs):
    #Wonky fix, the BLUF is don't zip dataloaders.
    train_dataloader = zip(DataLoader(host_train, batch_size=batch_size, shuffle=False), DataLoader(guest_train, batch_size=batch_size, shuffle=False))
    val_dataloader = zip(DataLoader(host_val, batch_size=batch_size, shuffle=False), DataLoader(guest_val, batch_size=batch_size, shuffle=False))
    
    train_loss = 0
    val_loss = 0
    valid_accuracy = 0
    print(epoch)
    for img, (dem, lbl) in train_dataloader:
        img = img.to(device)
        dem = dem.to(device)
        lbl = lbl.to(device).long().squeeze()
        #lbl = t.nn.functional.one_hot(lbl.to(t.int64),3).to(device)
        
        model.train(mode=True)
        output = model(img, dem)
            
        l = loss(output, lbl)
            
        train_loss += l.detach()/n_train
        
        optimizer.zero_grad()
        l.backward()
        optimizer.step()
    
    print('train_loss: ' + str(train_loss))

    for img, (dem, lbl) in val_dataloader:
        img = img.to(device)
        dem = dem.to(device)
        lbl = lbl.to(device).long().squeeze()
        
        model.train(mode=False)
        output = model(img, dem)
        l = loss(output, lbl)
        val_loss += l.detach()/n_val
    
    print('val_loss: ' + str(val_loss))
    
    train_loss_history.append(train_loss)
    val_loss_history.append(val_loss)


0
train_loss: tensor(0.6449, device='cuda:0')
val_loss: tensor(0.6975, device='cuda:0')
1
train_loss: tensor(0.6487, device='cuda:0')
val_loss: tensor(0.6997, device='cuda:0')
2
train_loss: tensor(0.6413, device='cuda:0')
val_loss: tensor(0.6965, device='cuda:0')
3
train_loss: tensor(0.6453, device='cuda:0')
val_loss: tensor(0.7014, device='cuda:0')
4
train_loss: tensor(0.6529, device='cuda:0')
val_loss: tensor(0.7026, device='cuda:0')
5
train_loss: tensor(0.6651, device='cuda:0')
val_loss: tensor(0.7200, device='cuda:0')
6
train_loss: tensor(0.6731, device='cuda:0')
val_loss: tensor(0.7047, device='cuda:0')
7
train_loss: tensor(0.6392, device='cuda:0')
val_loss: tensor(0.6985, device='cuda:0')
8
train_loss: tensor(0.6324, device='cuda:0')
val_loss: tensor(0.6974, device='cuda:0')
9
train_loss: tensor(0.6312, device='cuda:0')
val_loss: tensor(0.6983, device='cuda:0')
10
train_loss: tensor(0.6305, device='cuda:0')
val_loss: tensor(0.6992, device='cuda:0')
11
train_loss: tensor(0.6297, d

KeyboardInterrupt: 

In [12]:
val_loss_history

[tensor(0.9526),
 tensor(0.9262),
 tensor(0.9129),
 tensor(0.8174),
 tensor(0.7534),
 tensor(0.7473),
 tensor(0.7650),
 tensor(0.9416),
 tensor(0.7704),
 tensor(0.7066),
 tensor(0.7088),
 tensor(0.7044),
 tensor(0.7046),
 tensor(0.7032),
 tensor(0.7026),
 tensor(0.7020),
 tensor(0.7011),
 tensor(0.6994),
 tensor(0.6982),
 tensor(0.6977)]