# Train

In [None]:
import os
import argparse
import sys
import json
import pickle
import random

import torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
# (stray note) Traverse the folders; one folder corresponds to one category.
from torchvision import transforms

from tqdm import tqdm
import matplotlib.pyplot as plt

from model import swin_tiny_patch4_window7_224 as create_model

### 1. Divide the dataset

In [None]:

def read_split_data(root: str, val_rate: float = 0.1, test_rate: float = 0.1):
    """Split a folder-per-class dataset into training and validation lists.

    Every sub-directory of ``root`` is treated as one class. The class names
    are sorted and numbered, and the ``index -> name`` mapping is written to
    ``class_indices.json`` in the current working directory. For each class,
    the ``.gz`` files are sorted, 25% of them (rounded down) are drawn with
    ``random.sample`` for validation and the remainder go to training.
    A bar chart of the per-class file counts is shown with matplotlib.

    Args:
        root: Dataset root directory containing one sub-directory per class.
        val_rate: Accepted for interface compatibility but currently NOT
            used; the hold-out fraction is fixed at 0.25.
        test_rate: Accepted for interface compatibility but currently NOT
            used; no test split is produced.

    Returns:
        A 4-tuple ``(train_images_path, train_images_label,
        val_images_path, val_images_label)`` of parallel lists.

    Raises:
        AssertionError: If ``root`` does not exist, or if the training or
            validation list ends up empty.
    """
    assert os.path.exists(root), "dataset root: {} does not exist.".format(root)

    # One sub-directory per class; sorting keeps the class -> index mapping stable.
    patient_class = sorted(
        cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))
    )
    class_indices = {name: idx for idx, name in enumerate(patient_class)}

    # Persist the inverse mapping (index -> class name) for later cells.
    json_str = json.dumps({idx: name for name, idx in class_indices.items()}, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    # NOTE(review): val_rate / test_rate are ignored; the split is hard-coded.
    holdout_rate = 0.25

    train_images_path = []
    train_images_label = []
    val_images_path = []
    val_images_label = []
    test_images_path = []  # never filled: no test split is produced
    test_images_label = []
    every_class_num = []
    supported = [".gz"]  # matched against the last extension only (e.g. "x.nii.gz")

    for cla in patient_class:
        cla_path = os.path.join(root, cla)
        images = sorted(
            os.path.join(root, cla, i)
            for i in os.listdir(cla_path)
            if os.path.splitext(i)[-1] in supported
        )
        image_class = class_indices[cla]
        every_class_num.append(len(images))

        val_test_path = random.sample(images, k=int(len(images) * holdout_rate))
        # A set gives O(1) membership tests; the original list lookup made
        # this loop quadratic in the number of files per class.
        held_out = set(val_test_path)
        for img_path in images:
            if img_path in held_out:
                val_images_path.append(img_path)
                val_images_label.append(image_class)
            else:
                train_images_path.append(img_path)
                train_images_label.append(image_class)

    print("{} images were found in the dataset.".format(sum(every_class_num)))
    print("{} images for training.".format(len(train_images_path)))
    print("{} images for validation.".format(len(val_images_path)))
    print("{} images for test.".format(len(test_images_path)))
    assert len(train_images_path) > 0, "number of training images must greater than 0."
    assert len(val_images_path) > 0, "number of validation images must greater than 0."

    plot_image = True
    if plot_image:
        # Bar chart of the number of files per class, with the count above each bar.
        plt.bar(range(len(patient_class)), every_class_num, align='center')
        plt.xticks(range(len(patient_class)), patient_class)
        for i, v in enumerate(every_class_num):
            plt.text(x=i, y=v + 5, s=str(v), ha='center')
        plt.xlabel('image class')
        plt.ylabel('number of images')
        plt.title('patient class distribution')
        plt.show()

    return train_images_path, train_images_label, val_images_path, val_images_label


In [None]:
# Build the train/validation file and label lists from the dataset root.
(train_images_path,
 train_images_label,
 val_images_path,
 val_images_label) = read_split_data("/dataset")


### 2. Define MONAI transforms and instantiate the datasets

In [None]:
import monai
from monai.apps import download_and_extract
from monai.config import print_config
from monai.data import DataLoader, ImageDataset
from monai import transforms

# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("using {} device.".format(device))

# Training pipeline: channel-first layout, foreground crop, centre crop,
# random spatial crop, resize to 112x112x64, intensity normalisation and
# scaling, then random flips on each spatial axis and a random 90-degree
# rotation in the (0, 1) plane.
train_pipeline = transforms.Compose([
    transforms.EnsureChannelFirst(),
    transforms.CropForeground(k_divisible=1),
    transforms.CenterSpatialCrop(roi_size=(112, 121, 64)),
    transforms.RandSpatialCrop(roi_size=(80, 90, 44),
                               max_roi_size=(-1, -1, -1),
                               random_size=True),
    transforms.Resize(spatial_size=(112, 112, 64)),
    transforms.NormalizeIntensity(),
    transforms.ScaleIntensity(),
    *[transforms.RandFlip(prob=0.5, spatial_axis=axis) for axis in (0, 1, 2)],
    transforms.RandRotate90(prob=0.5, spatial_axes=(0, 1)),
    transforms.ToTensor(),
])

# Validation pipeline: the same steps as training without the random ones.
val_pipeline = transforms.Compose([
    transforms.EnsureChannelFirst(),
    transforms.CropForeground(k_divisible=1),
    transforms.CenterSpatialCrop(roi_size=(112, 121, 64)),
    transforms.Resize(spatial_size=(112, 112, 64)),
    transforms.NormalizeIntensity(),
    transforms.ScaleIntensity(),
    transforms.ToTensor(),  # convert to Tensor
])

# Test pipeline: defined here but not used by any dataset in this notebook.
test_pipeline = transforms.Compose([
    transforms.EnsureChannelFirst(),
    transforms.RepeatChannel(repeats=3),
    transforms.CropForeground(k_divisible=1),
    transforms.Resize(spatial_size=(96, 96, 96)),
    transforms.ToTensor(),  # convert to Tensor
])

data_transform = {
    "train": train_pipeline,
    "val": val_pipeline,
    "test": test_pipeline,
}

# Instantiate the training dataset.
train_dataset = ImageDataset(
    image_files=train_images_path,
    labels=train_images_label,
    transform=data_transform["train"],
)

# Instantiate the validation dataset.
val_dataset = ImageDataset(
    image_files=val_images_path,
    labels=val_images_label,
    transform=data_transform["val"],
)

# No test dataset is instantiated: the split above produces no test files.
# test_dataset = ImageDataset(image_files=test_images_path,
#                             labels=test_images_label,
#                             transform=data_transform["test"])


### 3. Create the DataLoaders

In [None]:
batch_size = 2   # can change

# Number of DataLoader workers: capped by the CPU count, the batch size and 8.
# os.cpu_count() may return None when the count cannot be determined, which
# would make min() raise a TypeError, so fall back to 1 in that case.
cpu_total = os.cpu_count() or 1
nw = min([cpu_total, batch_size if batch_size > 1 else 0, 8])  # number of workers
print('Using {} dataloader workers every process'.format(nw))

# Training batches are reshuffled on every pass over the loader.
train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          pin_memory=True,
                          num_workers=nw)

# Validation batches keep a fixed order.
val_loader = DataLoader(val_dataset,
                        batch_size=batch_size,
                        shuffle=False,
                        pin_memory=True,
                        num_workers=nw)

# No test loader is built because no test dataset exists in this notebook.

# Sample counts, used later to turn correct-prediction counts into accuracies.
train_num = len(train_dataset)
val_num = len(val_dataset)
print("using {} images for training, {} images for validation".format(train_num,
                                                                       val_num))


### 4. Display a slice

In [None]:
import numpy as np
# Preview one slice from the first training batch.
# After the permute below the last three axes are reversed; with the
# 112x112x64 resize above the batch is presumably (B, 1, 64, 112, 112).
for k, data in enumerate(train_loader):
    image = data[0].permute(0, 1, 4, 3, 2)
    data[0] = image
    label = data[1]
    print(image.shape)

    data_img = np.array(image)
    label_img = np.array(label)
    print(data_img.shape)
    print(label_img[0])

    # First sample, first channel, index 30 along the first spatial axis.
    x_slice_data = data_img[0, 0, 30, :, :]
    plt.imshow(x_slice_data, cmap='gray')
    plt.colorbar()
    plt.show()
    break  # only the first batch is shown

### 5. Load the pre-trained weights

In [None]:
## The pre-trained parameters of the original Video Swin Transformer are available for download on GitHub.
## Loading them is optional; the model can also be trained without them.
weights = './swin_tiny_patch244_window877_kinetics400_1k.pth'

# map_location="cpu" lets a checkpoint saved on a GPU be opened on a CPU-only
# machine; the tensors are copied to the model's device when they are loaded.
# NOTE(security): torch.load unpickles the file - only load checkpoints from a
# trusted source.
predict_model = torch.load(weights, map_location="cpu")

# Strip the 'backbone.' prefix so the checkpoint keys can match this model's
# parameter names.
renamed_state = {key.replace('backbone.', ''): value
                 for key, value in predict_model['state_dict'].items()}

## Match the weights of the Video Swin Transformer. Do not use other weights.
# Every entry whose key contains "patch_embed" is dropped, so those layers
# keep the model's own initialisation.
predict_model['state_dict'] = {key: value
                               for key, value in renamed_state.items()
                               if "patch_embed" not in key}

In [None]:
# Build the two-class model and move it to the selected device.
model = create_model(num_classes=2, init_weights=True).to(device)
net_dict = model.state_dict()

# Keep only checkpoint tensors whose name AND shape match the model. A tensor
# with a matching name but a different shape would make load_state_dict raise.
state_dict = {k: v for k, v in predict_model['state_dict'].items()
              if k in net_dict and v.shape == net_dict[k].shape}
shape_mismatch = [k for k in predict_model['state_dict']
                  if k in net_dict and k not in state_dict]
if shape_mismatch:
    print("skipped (shape mismatch): {}".format(shape_mismatch))

print(state_dict.keys())
# Overlay the pre-trained tensors on the model's own state and load the result.
net_dict.update(state_dict)
model.load_state_dict(net_dict)
print(model)

### 6. Train

In [None]:

# Loss, optimizer and learning-rate schedule.
loss_function = torch.nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=0.000005, weight_decay=1e-1)
# `verbose` is left out: False is already the default and the argument is
# deprecated in newer PyTorch releases.
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer, T_0=50, T_mult=2, eta_min=1e-7, last_epoch=-1)

epochs = 550
best_acc = 0.0
best_epoch = 0
# Per-epoch histories, plotted and saved by later cells.
train_loss = []
val_loss = []
tra_acc = []
val_acc = []

# Checkpoint path for the best model; create its directory up front so that
# torch.save cannot fail on a missing folder after an epoch has been trained.
save_path = './files/xxxxxxxxxx.pth'
os.makedirs(os.path.dirname(save_path), exist_ok=True)

train_steps = len(train_loader)
val_steps = len(val_loader)
for epoch in range(epochs):
    # ---- train ----
    model.train()

    running_loss = 0.0
    train_acc = 0  # number of correctly classified training samples
    train_bar = tqdm(train_loader, file=sys.stdout)
    for step, data in enumerate(train_bar):
        images, labels = data
        # Reverse the three spatial axes to match the layout the model expects.
        images = images.permute(0, 1, 4, 3, 2).to(device)
        labels = labels.to(device)  # moved once and reused for loss and accuracy

        optimizer.zero_grad()
        pred = model(images)
        loss = loss_function(pred, labels)
        loss.backward()
        optimizer.step()

        pred_train = torch.argmax(pred, dim=1)
        train_acc += torch.eq(pred_train, labels).sum().item()
        running_loss += loss.item()

        # Progress bar shows the running mean loss for this epoch.
        train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1, epochs, running_loss / (step + 1))

    # The schedule advances once per epoch.
    scheduler.step()
    print(optimizer.state_dict()['param_groups'][0]['lr'])  # print lr
    train_loss.append(running_loss / len(train_loader))

    # ---- validate ----
    model.eval()
    acc_loss = 0.0
    acc = 0.0  # number of correctly classified validation samples
    with torch.no_grad():
        val_bar = tqdm(val_loader, file=sys.stdout)
        for val_step, val_data in enumerate(val_bar):
            val_images, val_labels = val_data
            # Same axis reordering as in training.
            val_images = val_images.permute(0, 1, 4, 3, 2).to(device)
            val_labels = val_labels.to(device)

            pred = model(val_images)
            loss = loss_function(pred, val_labels)
            acc_loss += loss.item()

            predict_y = torch.argmax(pred, dim=1)
            acc += torch.eq(predict_y, val_labels).sum().item()

    train_accurate = train_acc / train_num
    val_accurate = acc / val_num
    tra_acc.append(train_accurate)
    val_acc.append(val_accurate)
    val_loss.append(acc_loss / len(val_loader))
    print('[epoch %d] train_loss: %.3f  train_accuracy:%.3f  val_loss:%.3f  val_accuracy: %.3f' %
          (epoch + 1, running_loss / train_steps, train_accurate, acc_loss / val_steps, val_accurate))

    # Keep only the weights of the best validation accuracy seen so far.
    if val_accurate > best_acc:
        best_acc = val_accurate
        best_epoch = epoch
        torch.save(model.state_dict(), save_path)

print('Finished Training')


# Plot the per-epoch loss and accuracy histories collected during training.
# A separate name is used for the x-axis so the `epochs` setting is not
# overwritten with a range object.
epoch_axis = range(len(train_loss))
plt.figure()
plt.plot(epoch_axis, train_loss, 'g', label='Loss of Training data')
plt.plot(epoch_axis, val_loss, 'b', label='Loss of Val data')
plt.plot(epoch_axis, val_acc, 'r', label='Acc of Val data')
plt.plot(epoch_axis, tra_acc, label='Acc of Train data')
plt.legend(loc=0)
# Show the plot; the original trailing plt.figure() only opened an empty figure.
plt.show()






In [None]:
import numpy as np
# Report the best validation accuracy and the (0-based) epoch it occurred in.
print(best_acc)
print(best_epoch)

# Persist each per-epoch history as its own .npy file.
for history_file, history in (
    ('./files/train_loss1.npy', train_loss),
    ('./files/val_loss1.npy', val_loss),
    ('./files/val_acc1.npy', val_acc),
    ('./files/tra_acc1.npy', tra_acc),
):
    np.save(history_file, history)

In [None]:
# Reload the saved histories so the curves can be redrawn without retraining.
tra_ls = np.load('./files/train_loss1.npy')
val_ls = np.load('./files/val_loss1.npy')
tra_ac = np.load('./files/tra_acc1.npy')
# The original assigned this to `val_ls` again, overwriting the validation loss.
val_ac = np.load('./files/val_acc1.npy')

# Plot the loaded arrays (not the in-memory lists), so this cell also works
# in a fresh kernel where the training cell has not been run.
epoch_axis = range(len(tra_ls))
plt.figure()
plt.plot(epoch_axis, tra_ls, 'g', label='Loss of Training data')
plt.plot(epoch_axis, val_ls, 'b', label='Loss of Val data')
plt.plot(epoch_axis, val_ac, 'r', label='Acc of Val data')
plt.plot(epoch_axis, tra_ac, label='Acc of Train data')
plt.legend(loc=0)
plt.show()


In [None]:
# Read the index -> class-name mapping written by read_split_data.
json_path = './class_indices.json'
assert os.path.exists(json_path), "file: '{}' does not exist.".format(json_path)

with open(json_path, "r") as f:
    class_indict = json.load(f)

# The mapping from the file is used as-is. The original replaced it with a
# hard-coded {'0': 'AD', '1': 'CN'}, which would mislabel the report whenever
# the dataset folders differ from those two names.
print(class_indict)

### 7. Model evaluation

In [None]:
import numpy as np
from sklearn.metrics import classification_report, confusion_matrix
import pandas as pd
import seaborn as sns


# Load the best checkpoint written by the training cell. The file name must
# match `save_path` there; the original used a different placeholder name, so
# the existence check failed right after training.
weights_path = './files/xxxxxxxxxx.pth'
assert os.path.exists(weights_path), "file: '{}' does not exist.".format(weights_path)
model.load_state_dict(torch.load(weights_path, map_location=device))

# NOTE: no test split exists in this notebook, so the evaluation runs on the
# validation loader.
predictions = []
labels = []
model.eval()
with torch.no_grad():
    test_bar = tqdm(val_loader, file=sys.stdout)
    for test_data in test_bar:
        test_images, test_labels = test_data
        # Same axis reordering as in training.
        test_images = test_images.permute(0, 1, 4, 3, 2)
        outputs = model(test_images.to(device))
        predict_y = torch.argmax(outputs, dim=1).cpu()
        predictions.extend(predict_y.numpy())
        labels.extend(test_labels.cpu().numpy())

predictions = np.array(predictions)
labels = np.array(labels)
# Print evaluation report; a plain list of names is passed instead of a
# dict view.
print(classification_report(labels, predictions, target_names=list(class_indict.values())))

    
    
  

### 8. Confusion matrix

In [None]:
# Class names in index order (class indices are 0..n-1).
class_names = list(class_indict.values())

# Pin the label set so the matrix is always n x n, even if a class is absent
# from both the true and the predicted labels.
cm = confusion_matrix(labels, predictions, labels=list(range(len(class_names))))
cm_df = pd.DataFrame(cm, index=class_names, columns=class_names)

plt.figure(dpi=300)
# The cells are integer counts, so format them as integers ("d"), not ".1f".
sns.heatmap(cm_df, annot=True, cmap="Blues", fmt="d")
plt.title("Confusion Matrix", fontweight="bold")
plt.xlabel("Predicted", fontweight="bold")
plt.ylabel("True", fontweight="bold")
plt.show()

# 