# Cortx Pytorch Integration

This project builds and demonstrates a pytorch integration. The project uses the Cortx framework to store some data. A Pytorch dataloader is created to download the data and use it for a classification task.

We use the Medical MNIST dataset from Kaggle available at https://www.kaggle.com/andrewmvd/medical-mnist

In [1]:
import os
import sys
import boto3
import logging
import shutil
import threading
from botocore.client import Config
from matplotlib import pyplot as plt
from botocore.exceptions import ClientError
from boto3.s3.transfer import TransferConfig

In [2]:
class ProgressPercentage(object):
    def __init__(self, filename):
        self._filename = filename
        self._size = float(os.path.getsize(filename))
        self._seen_so_far = 0
        self._lock = threading.Lock()
        
    def __call__(self, bytes_amount):
        # To simplify, assume this is hooked up to a single filename
        with self._lock:
            self._seen_so_far += bytes_amount
            percentage = (self._seen_so_far / self._size) * 100
            sys.stdout.write("\r%s  %s / %s  (%.2f%%)" %
                             (self._filename, self._seen_so_far, 
                              self._size, percentage))
            sys.stdout.flush()

def multipart_upload_with_s3(bucket_name, file_path=None, object_name=None):
    # Multipart upload (see notes)
    config = TransferConfig(multipart_threshold=1024 * 25, max_concurrency=10,
                            multipart_chunksize=1024 * 25, use_threads=True)
    key_path = 'multipart_files/{}'.format(object_name)
    print(bucket_name,file_path,object_name,key_path)
    s3_client.upload_file(file_path, bucket_name, key_path,
                          ExtraArgs={'ACL': 'public-read', 
                                     'ContentType': 'text/pdf'},
                          Config=config, Callback=ProgressPercentage(file_path))
    
def multipart_download_with_s3(bucket_name, file_path=None, object_name=None):
    config = TransferConfig(multipart_threshold=1024 * 25, max_concurrency=10,
                            multipart_chunksize=1024 * 25, use_threads=True)
    temp_file = os.path.dirname(__file__)
    s3_resource.Object(bucket_name, 
                       object_name
                       ).download_file(file_path, Config=config,
                                       Callback=ProgressPercentage(temp_file))

"""Functions for buckets operation"""
def create_bucket_op(bucket_name, region):
    if region is None:
        s3_client.create_bucket(Bucket=bucket_name)
    else:
        location = {'LocationConstraint': region}
        s3_client.create_bucket(Bucket=bucket_name, 
                                CreateBucketConfiguration=location)

def list_bucket_op(bucket_name, region, operation):
    buckets = s3_client.list_buckets()
    if buckets['Buckets']:
        for bucket in buckets['Buckets']:
            print(bucket)
            return True
    else:
        logging.error('unknown bucket operation')
        return False
    
def bucket_operation(bucket_name, region=None, operation='list'):
    try:
        if operation == 'delete':
            s3_client.delete_bucket(Bucket=bucket_name)
        elif operation == 'create':
            create_bucket_op(bucket_name, region)
        elif operation == 'list':
            return list_bucket_op(bucket_name, region, operation)
        else:
            logging.error('unknown bucket operation')
            return False
    except ClientError as e:
        logging.error(e)
        return False
    return True

"""Functions for objects operation"""
def list_object_op(bucket_name):
     s3_objects = s3_client.list_objects_v2(Bucket=bucket_name)
     if s3_objects.get('Contents'):
         for obj in s3_objects['Contents']:
             print(obj)

def delete_object_op(bucket_name, object_name, operation):
    if not object_name:
        logging.error('object_name missing for {}'.format(operation))
        return False  
    s3_client.delete_object(Bucket=bucket_name, Key=object_name)
    return True

def upload_download_object_op(bucket_name, object_name, file_path, operation):
    if not file_path or not object_name:
        logging.error('file_path and/or object_name missing for upload')
        return False
    if operation == 'upload':
        multipart_upload_with_s3(bucket_name=bucket_name, file_path=file_path,
                                 object_name=object_name)
    else:
        multipart_download_with_s3(bucket_name=bucket_name, file_path=file_path,
                                   object_name=object_name)
    return True

def object_operation(bucket_name=None, object_name=None, file_path=None,
                     operation='list'):                                                             
    try:
        if not bucket_name:
            logging.error('The bucket name %s is missing for %s operation!'
                          % (bucket_name, operation))
            return False
        if operation == 'list':
            list_object_op(bucket_name)
        elif operation == 'delete':
            return delete_object_op(bucket_name, object_name, operation)
        elif operation == 'upload' or operation == 'download':
            return upload_download_object_op(bucket_name, object_name,
                                             file_path, operation)      
        else:
            logging.error('unknown object operation')
            return False
    except ClientError as e:
        logging.error(e)
        return False
    return True

"""Functions for files operation"""
def list_op_file(bucket_name):
    current_bucket = s3_resource.Bucket(bucket_name)
    print('The files in bucket %s:\n' % (bucket_name))
    for obj in current_bucket.objects.all():
        print(obj.meta.data) 
        
    return True

def delete_op_file(bucket_name, file_name, operation):
    if not file_name:
        logging.error('The file name %s is missing for%s operation!' 
                      % (file_name, operation))
        return False
    s3_client.delete_object(Bucket=bucket_name, Key=file_name)
    return True
    
def upload_download_op_file(bucket_name, file_name, file_location,
                            region, operation):
    if not file_location:
        logging.error('The file location %d is missing for %s operation!'
                      % (file_location, operation))
        return False
    if operation == 'download':
        s3_resource.Bucket(bucket_name).download_file(file_name, file_location)
    elif operation == 'upload' and region is None:
        s3_resource.Bucket(bucket_name).upload_file(file_location, file_name)
    else:
         location = {'LocationConstraint': region}
         s3_resource.Bucket(bucket_name
                            ).upload_file(file_location, file_name,
                                          CreateBucketConfiguration=location) 
    return True
    
def file_operation(bucket_name=None, file_name=None, file_location=None, 
                   region=None, operation='list'):
    if not bucket_name:
        logging.error('The bucket name is %s missing!' % (bucket_name))
        return False 
    try:
        if operation == 'list':
            return list_op_file(bucket_name)
        elif operation == 'delete':
            return delete_op_file(bucket_name, file_name, operation)  
        elif operation == 'upload' or operation == 'download':
            return upload_download_op_file(bucket_name, file_name, 
                                           file_location, region, operation)
        else:
            logging.error('unknown file operation')
            return False  
    except ClientError as e:
        logging.error(e)
        return False
    return True

In [3]:
END_POINT_URL = 'http://uvo17qqh4jn92xmchpj.vm.cld.sr'
A_KEY = 'AKIAtEpiGWUcQIelPRlD1Pi6xQ'
S_KEY = 'YNV6xS8lXnCTGSy1x2vGkmGnmdJbZSapNXaSaRhK'

In [4]:
s3_resource = boto3.resource('s3', endpoint_url=END_POINT_URL,
                             aws_access_key_id=A_KEY,
                             aws_secret_access_key=S_KEY,
                             config=Config(signature_version='s3v4'),
                             region_name='US')

s3_client = boto3.client('s3', endpoint_url=END_POINT_URL,
                         aws_access_key_id=A_KEY,
                         aws_secret_access_key=S_KEY,
                         config=Config(signature_version='s3v4'),
                         region_name='US')

In [6]:
list_op_file('headct')

The files in bucket headct:

{'Key': '003332.jpeg', 'LastModified': datetime.datetime(2021, 4, 27, 23, 50, 34, tzinfo=tzutc()), 'ETag': '"520d8ea45f112a33f5f97d39b8df9e79"', 'Size': 1045, 'StorageClass': 'STANDARD', 'Owner': {'DisplayName': 'S3user', 'ID': 'e500ea6b45f64f068ab001b7f1fdfc57ed2faed247474d81b66f69a6233727c8'}}
{'Key': '003424.jpeg', 'LastModified': datetime.datetime(2021, 4, 27, 23, 50, 32, tzinfo=tzutc()), 'ETag': '"b41954570a6f85f5e42affa9740791a2"', 'Size': 1344, 'StorageClass': 'STANDARD', 'Owner': {'DisplayName': 'S3user', 'ID': 'e500ea6b45f64f068ab001b7f1fdfc57ed2faed247474d81b66f69a6233727c8'}}
{'Key': '003904.jpeg', 'LastModified': datetime.datetime(2021, 4, 27, 23, 50, 37, tzinfo=tzutc()), 'ETag': '"71122c5c904fb9c58a780c9eb61ede41"', 'Size': 1219, 'StorageClass': 'STANDARD', 'Owner': {'DisplayName': 'S3user', 'ID': 'e500ea6b45f64f068ab001b7f1fdfc57ed2faed247474d81b66f69a6233727c8'}}
{'Key': '004277.jpeg', 'LastModified': datetime.datetime(2021, 4, 27, 23, 50, 35

True

## Create MEDMINST data buckets on Cortx

In [7]:
class_names = ['abdomenct', 'breastmri', 'chestct', 'cxr', 'hand', 'headct']

In [33]:
for class_name in class_names:
    bucket_operation(class_name, None, 'create')
    print('created bucket for class: {}'.format(class_name))


ERROR:root:An error occurred (BucketAlreadyOwnedByYou) when calling the CreateBucket operation: The bucket you tried to create already exists, and you own it.


created bucket for class: abdomenct


ERROR:root:An error occurred (BucketAlreadyOwnedByYou) when calling the CreateBucket operation: The bucket you tried to create already exists, and you own it.


created bucket for class: breastmri


ERROR:root:An error occurred (BucketAlreadyOwnedByYou) when calling the CreateBucket operation: The bucket you tried to create already exists, and you own it.


created bucket for class: chestct


ERROR:root:An error occurred (BucketAlreadyOwnedByYou) when calling the CreateBucket operation: The bucket you tried to create already exists, and you own it.


created bucket for class: cxr


ERROR:root:An error occurred (BucketAlreadyOwnedByYou) when calling the CreateBucket operation: The bucket you tried to create already exists, and you own it.


created bucket for class: hand


ERROR:root:An error occurred (BucketAlreadyOwnedByYou) when calling the CreateBucket operation: The bucket you tried to create already exists, and you own it.


created bucket for class: headct


In [8]:
bucket_operation(None, None, operation='list')

{'Name': 'abdomenct', 'CreationDate': datetime.datetime(2021, 4, 27, 23, 2, 19, tzinfo=tzutc())}


True

In [9]:
dataset_base_path = '/home/dreamchild/sgcortex/data/medmnist/'

In [10]:
local_class_dirs = os.listdir(dataset_base_path)

In [11]:
local_class_dirs

['Hand', 'CXR', 'AbdomenCT', 'BreastMRI', 'HeadCT', 'ChestCT']

In [12]:
local_class_dir_files = {}

for class_dir in local_class_dirs:
    class_dir_files = []
    for file in os.listdir(dataset_base_path + class_dir)[:10]:
        class_dir_files.append(file)
    local_class_dir_files[class_dir] = class_dir_files

In [13]:
local_class_dir_files

{'Hand': ['005383.jpeg',
  '008405.jpeg',
  '003424.jpeg',
  '006537.jpeg',
  '009181.jpeg',
  '009386.jpeg',
  '003332.jpeg',
  '004277.jpeg',
  '007914.jpeg',
  '003904.jpeg'],
 'CXR': ['005383.jpeg',
  '008405.jpeg',
  '003424.jpeg',
  '006537.jpeg',
  '009181.jpeg',
  '009386.jpeg',
  '003332.jpeg',
  '004277.jpeg',
  '007914.jpeg',
  '003904.jpeg'],
 'AbdomenCT': ['005383.jpeg',
  '008405.jpeg',
  '003424.jpeg',
  '006537.jpeg',
  '009181.jpeg',
  '009386.jpeg',
  '003332.jpeg',
  '004277.jpeg',
  '007914.jpeg',
  '003904.jpeg'],
 'BreastMRI': ['005383.jpeg',
  '008405.jpeg',
  '003424.jpeg',
  '006537.jpeg',
  '003332.jpeg',
  '004277.jpeg',
  '007914.jpeg',
  '003904.jpeg',
  '000689.jpeg',
  '005112.jpeg'],
 'HeadCT': ['005383.jpeg',
  '008405.jpeg',
  '003424.jpeg',
  '006537.jpeg',
  '009181.jpeg',
  '009386.jpeg',
  '003332.jpeg',
  '004277.jpeg',
  '007914.jpeg',
  '003904.jpeg'],
 'ChestCT': ['005383.jpeg',
  '008405.jpeg',
  '003424.jpeg',
  '006537.jpeg',
  '009181.jpeg'

### Upload (10) files from each category into cortx

In [14]:
for k in local_class_dir_files.keys():
    for file in local_class_dir_files[k]:
        upload_file_path = dataset_base_path + k + '/' + file
        file_operation(k.lower(), file, upload_file_path, None, 'upload')
        print('uploaded file {} to bucket {}'.format(upload_file_path, k.lower()))

uploaded file /home/dreamchild/sgcortex/data/medmnist/Hand/005383.jpeg to bucket hand
uploaded file /home/dreamchild/sgcortex/data/medmnist/Hand/008405.jpeg to bucket hand
uploaded file /home/dreamchild/sgcortex/data/medmnist/Hand/003424.jpeg to bucket hand
uploaded file /home/dreamchild/sgcortex/data/medmnist/Hand/006537.jpeg to bucket hand
uploaded file /home/dreamchild/sgcortex/data/medmnist/Hand/009181.jpeg to bucket hand
uploaded file /home/dreamchild/sgcortex/data/medmnist/Hand/009386.jpeg to bucket hand
uploaded file /home/dreamchild/sgcortex/data/medmnist/Hand/003332.jpeg to bucket hand
uploaded file /home/dreamchild/sgcortex/data/medmnist/Hand/004277.jpeg to bucket hand
uploaded file /home/dreamchild/sgcortex/data/medmnist/Hand/007914.jpeg to bucket hand
uploaded file /home/dreamchild/sgcortex/data/medmnist/Hand/003904.jpeg to bucket hand
uploaded file /home/dreamchild/sgcortex/data/medmnist/CXR/005383.jpeg to bucket cxr
uploaded file /home/dreamchild/sgcortex/data/medmnist/CX

### Setup Pytorch dataloader

In [15]:
file_operation('hand',operation='list')

The files in bucket hand:

{'Key': '003332.jpeg', 'LastModified': datetime.datetime(2021, 4, 28, 22, 7, tzinfo=tzutc()), 'ETag': '"7aab34fe6035251e8bd656c28ce96ef6"', 'Size': 1357, 'StorageClass': 'STANDARD', 'Owner': {'DisplayName': 'S3user', 'ID': 'e500ea6b45f64f068ab001b7f1fdfc57ed2faed247474d81b66f69a6233727c8'}}
{'Key': '003424.jpeg', 'LastModified': datetime.datetime(2021, 4, 28, 22, 6, 57, tzinfo=tzutc()), 'ETag': '"f9207043d0836eac5d7a282a4fffa152"', 'Size': 1555, 'StorageClass': 'STANDARD', 'Owner': {'DisplayName': 'S3user', 'ID': 'e500ea6b45f64f068ab001b7f1fdfc57ed2faed247474d81b66f69a6233727c8'}}
{'Key': '003904.jpeg', 'LastModified': datetime.datetime(2021, 4, 28, 22, 7, 2, tzinfo=tzutc()), 'ETag': '"1e1331b574296a60aff80e0247ba9f7b"', 'Size': 1824, 'StorageClass': 'STANDARD', 'Owner': {'DisplayName': 'S3user', 'ID': 'e500ea6b45f64f068ab001b7f1fdfc57ed2faed247474d81b66f69a6233727c8'}}
{'Key': '004277.jpeg', 'LastModified': datetime.datetime(2021, 4, 28, 22, 7, 1, tzinfo=tzu

True

## CortxDataLoader

Download datasets from remote CORTX instance. Each class file is stored in a different bucket.

In [57]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from PIL import Image, ImageFile

ImageFile.LOAD_TRUNCATED_IMAGES=True
import random

In [36]:
random.random()

0.8693079292631422

In [58]:
img_transforms = transforms.Compose([
    transforms.Resize((64,64)),    
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225] )
    ])


## Introducing the CORTX Dataloader

The CORTX Dataloader downloads datasets from your CORTX instance. It automatically sorts your data into a training and testing set using an 80 percent split

In [74]:
class CortxDataLoader:
    
    def __init__(self, endpoint = None, accessKey = None, serviceKey = None, class_buckets = [], batch_size = 16):
        self.endpoint = endpoint
        self.accessKey = accessKey
        self.serviceKey = serviceKey
        self.class_buckets = class_buckets
        self.batch_size = batch_size
        
    def download_dataset(self, destination_dir = None):
        for bucket in self.class_buckets:
            s3_bucket = s3_resource.Bucket(bucket)
            print('Downloading files for class bucket: {}'.format(bucket))
            #files_in_bucket = s3_resource.Bucket(bucket).objects.all()
            folders = ['train','validation','test']
            current_bucket_objects = s3_bucket.objects.all()
            for fileObj in current_bucket_objects:       
                rand = random.random()
                filename = fileObj.key
                if not os.path.exists('/home/dreamchild/sgcortex/data/medmnist_dl/train/' + bucket):                    
                    os.makedirs('/home/dreamchild/sgcortex/data/medmnist_dl/train/' + bucket)
                    os.makedirs('/home/dreamchild/sgcortex/data/medmnist_dl/val/' + bucket)
                folder = "train" if rand < 0.8 else "val"
                with open('/home/dreamchild/sgcortex/data/medmnist_dl/'+ folder + '/' + bucket + '/' + filename, 'wb') as f:
                    s3_bucket.download_fileobj(filename, f)
                    print('downloaded file: {}'.format(filename))

    def train_data_loader(self):
        train_data_path = '/home/dreamchild/sgcortex/data/medmnist_dl/train/'
        train_data = torchvision.datasets.ImageFolder(root=train_data_path,transform=img_transforms)
        train_data_loader = torch.utils.data.DataLoader(train_data, batch_size=self.batch_size)
        return train_data_loader
    
    def test_data_loader(self):
        test_data_path = '/home/dreamchild/sgcortex/data/medmnist_dl/test/'
        test_data = torchvision.datasets.ImageFolder(root=test_data_path,transform=img_transforms)
        test_data_loader = torch.utils.data.DataLoader(test_data, batch_size = self.batch_size)
        return test_data_loader
    
    def val_data_loader(self):
        val_data_path = '/home/dreamchild/sgcortex/data/medmnist_dl/val/'
        val_data = torchvision.datasets.ImageFolder(root=val_data_path,transform=img_transforms)
        val_data_loader = torch.utils.data.DataLoader(val_data, batch_size=self.batch_size)
        return val_data_loader

Initialize a new CortxDataLoader with different buckets:

In [75]:
cortx_loader = CortxDataLoader(class_buckets=['hand','breastmri','chestct','abdomenct','cxr','headct'])

Download data to a local destination folder:

In [76]:
cortx_loader.download_dataset('/home/dreamchild/sgcortex/data/medmnist-dl')

Downloading files for class bucket: hand
downloaded file: 003332.jpeg
downloaded file: 003424.jpeg
downloaded file: 003904.jpeg
downloaded file: 004277.jpeg
downloaded file: 005383.jpeg
downloaded file: 006537.jpeg
downloaded file: 007914.jpeg
downloaded file: 008405.jpeg
downloaded file: 009181.jpeg
downloaded file: 009386.jpeg
Downloading files for class bucket: breastmri
downloaded file: 000689.jpeg
downloaded file: 003332.jpeg
downloaded file: 003424.jpeg
downloaded file: 003904.jpeg
downloaded file: 004277.jpeg
downloaded file: 005112.jpeg
downloaded file: 005383.jpeg
downloaded file: 006537.jpeg
downloaded file: 007914.jpeg
downloaded file: 008405.jpeg
Downloading files for class bucket: chestct
downloaded file: 003332.jpeg
downloaded file: 003424.jpeg
downloaded file: 003904.jpeg
downloaded file: 004277.jpeg
downloaded file: 005383.jpeg
downloaded file: 006537.jpeg
downloaded file: 007914.jpeg
downloaded file: 008405.jpeg
downloaded file: 009181.jpeg
downloaded file: 009386.jpeg

## Setup Pytorch

A simple Neural Network:

In [77]:
class SimpleNet(nn.Module):

    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc1 = nn.Linear(12288, 84)
        self.fc2 = nn.Linear(84, 50)
        self.fc3 = nn.Linear(50,2)
    
    def forward(self, x):
        x = x.view(-1, 12288)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [78]:
simplenet = SimpleNet()

Setup the optimizer

In [79]:
optimizer = optim.Adam(simplenet.parameters(), lr=0.001)

Use GPU if available else use CPU

In [68]:
if torch.cuda.is_available():
    device = torch.device("cuda") 
else:
    device = torch.device("cpu")

simplenet.to(device)

SimpleNet(
  (fc1): Linear(in_features=12288, out_features=84, bias=True)
  (fc2): Linear(in_features=84, out_features=50, bias=True)
  (fc3): Linear(in_features=50, out_features=2, bias=True)
)

In [69]:
torch.cuda.is_available()

True

Setup training loop

In [80]:
def train(model, optimizer, loss_fn, train_loader, val_loader, epochs=20, device="cpu"):
    for epoch in range(1, epochs+1):
        training_loss = 0.0
        valid_loss = 0.0
        model.train()
        for batch in train_loader:
            optimizer.zero_grad()
            inputs, targets = batch
            inputs = inputs.to(device)
            targets = targets.to(device)
            output = model(inputs)
            loss = loss_fn(output, targets)
            loss.backward()
            optimizer.step()
            training_loss += loss.data.item() * inputs.size(0)
        training_loss /= len(train_loader.dataset)
        
        model.eval()
        num_correct = 0 
        num_examples = 0
        for batch in val_loader:
            inputs, targets = batch
            inputs = inputs.to(device)
            output = model(inputs)
            targets = targets.to(device)
            loss = loss_fn(output,targets) 
            valid_loss += loss.data.item() * inputs.size(0)
            correct = torch.eq(torch.max(F.softmax(output, dim=1), dim=1)[1], targets)
            num_correct += torch.sum(correct).item()
            num_examples += correct.shape[0]
        valid_loss /= len(val_loader.dataset)

        print('Epoch: {}, Training Loss: {:.2f}, Validation Loss: {:.2f}, accuracy = {:.2f}'.format(epoch, training_loss,
        valid_loss, num_correct / num_examples))

## Run the training using the Cortx Dataloader

In [None]:
train(simplenet, optimizer,torch.nn.CrossEntropyLoss(), cortx_loader.train_data_loader(), cortx_loader.val_data_loader(), epochs=5, device="cpu")