# Preprocess_brain_map

In [2]:
!pip install -q neuroquery
!pip install -q nilearn

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.3/10.3 MB[0m [31m76.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.3/3.3 MB[0m [31m80.2 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [3]:
import neuroquery
import nilearn
import time
import os


import nibabel as nib
import numpy as np
import pandas as pd


from tqdm import tqdm
from joblib import Memory
from nilearn import image
from neuroquery._compat import maskers, load_mni152_brain_mask

coord2map

In [None]:
def get_masker(mask_img=None, target_affine=None):
    """Return a fitted ``NiftiMasker`` for the requested mask and resolution.

    Parameters
    ----------
    mask_img : NiftiMasker, image, or None
        An existing ``NiftiMasker`` is returned unchanged.  ``None`` selects
        the image returned by ``load_mni152_brain_mask()``.
    target_affine : scalar, 1-D sequence, 2-D array, or None
        When given, the mask is resampled to this affine with
        nearest-neighbour interpolation.  A scalar is expanded to
        ``np.eye(3) * scalar`` and a 1-D sequence to ``np.diag(sequence)``.

    Returns
    -------
    NiftiMasker
        A masker fitted on the (possibly resampled) mask image.
    """
    if isinstance(mask_img, maskers.NiftiMasker):
        return mask_img

    mask = load_mni152_brain_mask() if mask_img is None else mask_img

    if target_affine is not None:
        n_dims = np.ndim(target_affine)
        if n_dims == 0:
            target_affine = np.eye(3) * target_affine
        elif n_dims == 1:
            target_affine = np.diag(target_affine)
        mask = image.resample_img(
            mask, target_affine=target_affine, interpolation="nearest"
        )

    return maskers.NiftiMasker(mask_img=mask).fit()


def coords_to_voxels(coords, ref_img=None):
    """Convert world-space coordinates to integer voxel indices of ``ref_img``.

    Parameters
    ----------
    coords : array-like
        One ``(x, y, z)`` triple or a sequence of them.
    ref_img : image or None
        Object exposing ``.affine`` and ``.shape``; ``None`` selects the image
        returned by ``load_mni152_brain_mask()``.

    Returns
    -------
    numpy.ndarray
        Integer array with one row per coordinate that falls inside the
        image grid; coordinates outside the grid are dropped.
    """
    if ref_img is None:
        ref_img = load_mni152_brain_mask()

    world_to_voxel = np.linalg.pinv(ref_img.affine)
    xyz = np.atleast_2d(coords)
    # Append a column of ones so the affine can be applied as one product.
    homogeneous = np.hstack([xyz, np.ones((len(xyz), 1))])
    ijk = world_to_voxel.dot(homogeneous.T)[:-1].T

    # Bounds are checked on the fractional indices, before flooring.
    ijk = ijk[(ijk >= 0).all(axis=1)]
    ijk = ijk[(ijk < ref_img.shape[:3]).all(axis=1)]
    return np.floor(ijk).astype(int)


def coords_to_peaks_img(coords, mask_img):
    """Build an image that counts how many coordinates land in each voxel.

    ``mask_img`` is loaded with ``image.load_img`` and provides the grid
    (shape and affine).  Coordinates mapping to the same voxel accumulate;
    coordinates outside the grid are ignored by ``coords_to_voxels``.
    """
    reference = image.load_img(mask_img)
    hit_counts = np.zeros(reference.shape)
    voxel_index = tuple(coords_to_voxels(coords, reference).T)
    # np.add.at accumulates repeated indices instead of overwriting them.
    np.add.at(hit_counts, voxel_index, 1.0)
    return image.new_img_like(reference, hit_counts)


def gaussian_coord_smoothing(
    coords, mask_img=None, target_affine=None, fwhm=9.0
):
    """Turn peak coordinates into a smoothed image.

    A masker is obtained from ``get_masker(mask_img, target_affine)``, the
    coordinates are placed on that masker's mask grid as peak counts, and
    the result is passed to ``image.smooth_img`` with the given ``fwhm``.
    """
    fitted_masker = get_masker(mask_img, target_affine)
    return image.smooth_img(
        coords_to_peaks_img(coords, mask_img=fitted_masker.mask_img_),
        fwhm=fwhm,
    )


def coordinates_to_maps(
    coordinates, mask_img=None, target_affine=(4, 4, 4), fwhm=9.0
):
    """Create one smoothed image per article found in ``coordinates``.

    Parameters
    ----------
    coordinates : pandas.DataFrame
        Must provide the columns ``id``, ``x``, ``y`` and ``z`` (these are
        the columns read by ``iter_coordinates_to_maps``).
    mask_img, target_affine
        Forwarded to ``get_masker`` to build the masker shared by all maps.
    fwhm : float
        Smoothing kernel width forwarded to the per-article smoothing.

    Returns
    -------
    (list, NiftiMasker)
        The images in ``groupby("id")`` order, and the masker that was used.
        The list is empty when ``coordinates`` has no rows.
    """
    masker = get_masker(mask_img=mask_img, target_affine=target_affine)
    # The article ids yielded alongside each image are not part of the
    # return value, so they are discarded here rather than collected.
    images = [
        img
        for _, img in iter_coordinates_to_maps(
            coordinates, mask_img=masker, fwhm=fwhm
        )
    ]
    return images, masker


def iter_coordinates_to_maps(
    coordinates, mask_img=None, target_affine=(4, 4, 4), fwhm=9.0
):
    """Lazily yield ``(article_id, smoothed_image)`` pairs.

    ``coordinates`` is grouped by its ``id`` column; for every group the
    ``x``/``y``/``z`` columns are smoothed with ``gaussian_coord_smoothing``
    on the masker obtained from ``get_masker(mask_img, target_affine)``.
    """
    masker = get_masker(mask_img=mask_img, target_affine=target_affine)
    for pmid, article_coords in coordinates.groupby("id"):
        xyz = article_coords.loc[:, ["x", "y", "z"]].values
        yield pmid, gaussian_coord_smoothing(xyz, fwhm=fwhm, mask_img=masker)

main

In [None]:
if __name__=='__main__':
    # Convert every NeuroQuery article's peak coordinates into a smoothed
    # brain map and save it as <article id>.nii.gz in ``out_dir``.
    cache_directory = "/disk1/wyn/workshop/ChatGPT/text2brain-main/cache/"
    out_dir = '/disk1/wyn/workshop/ChatGPT/text2brain-main/data/brain_maps/neuroquery/'

    neuroquery_coord = pd.read_table(
        '/disk1/wyn/workshop/ChatGPT/text2brain-main/data-neuroquery_version-1_coordinates.tsv.gz') # coordinates data
    neuroquery_meta = pd.read_table(
        '/disk1/wyn/workshop/ChatGPT/text2brain-main/data-neuroquery_version-1_metadata.tsv.gz') # articles data
    mask_152 = nib.load('/disk1/wyn/workshop/ChatGPT/text2brain-main/mask_img.nii') # brain templates MNI152

    # nib.save does not create missing directories.
    os.makedirs(out_dir, exist_ok=True)

    coord_to_maps = Memory(cache_directory).cache(coordinates_to_maps)

    # Affine written into every output file (4 mm voxels); it does not
    # depend on the article, so it is built once.
    affine = np.array([
        [4., 0., 0., -90.],
        [0., 4., 0., -126.],
        [0., 0., 4., -72.],
        [0., 0., 0., 1.]
    ])

    for index in tqdm(neuroquery_meta.index):
        article_id = int(neuroquery_meta.at[index, 'id'])
        brain_map_path = os.path.join(out_dir, str(article_id) + '.nii.gz') # save path for every brain map

        coord = neuroquery_coord.loc[neuroquery_coord['id'] == article_id]
        brain_maps, masker = coord_to_maps(
            coord, mask_img=mask_152, target_affine=(4, 4, 4), fwhm=9.0
        ) # transfer coordinates to maps
        if not brain_maps:
            # No coordinate rows for this article: there is nothing to
            # smooth, and indexing brain_maps[0] would raise IndexError.
            print(f'No coordinates for article {article_id}; skipped')
            continue

        # A transform / inverse_transform round trip through the masker
        # returns the map as an image on the mask grid.
        brain_map = masker.inverse_transform(masker.transform(brain_maps[0]).squeeze())
        brain_map = brain_map.get_fdata()

        # Copy only the inner sub-volume; the border voxels stay zero.
        vol_data = np.zeros((46, 55, 46)) # coordinates space of MNI152 with edges
        vol_data[3:-3, 3:-4, :-6] = brain_map[3:-3, 3:-4, :-6]

        nib.save(nib.Nifti1Image(vol_data, affine), brain_map_path) # save as NibImage

# ChatAUG

In [4]:
!pip install -q openai

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/73.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m73.6/73.6 kB[0m [31m2.7 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [5]:
import openai
import os
import time


import numpy as np
import pandas as pd


from tqdm import tqdm

text2aug

In [None]:
def text2aug(text):
  """Ask the chat model for augmented text derived from a paper title.

  A few-shot conversation (system instruction, one worked example, then the
  request for ``text``) is sent to ``gpt-3.5-turbo``.  The worked example
  shows the layout the reply is expected to follow:
  ``Title1:``, ``Title2:``, ``Abstract:``, ``Experiment:``, ``Keywords:``.

  Args:
    text: Value substituted into the final ``Title: {}`` prompt via
      ``str.format``.

  Returns:
    The message content of the first choice in the response.  If the last
    of the three attempts fails with ``APIError``, ``Timeout`` or
    ``APIConnectionError``, the error name is printed and ``'\\n'`` is
    returned instead.

  Raises:
    Any exception other than the three handled below is not caught here.
  """
  max_tries = 3
  for try_number in range(max_tries):
    try:
      response = openai.ChatCompletion.create(
          model = 'gpt-3.5-turbo',
          messages = [
              {'role': 'system', 'content': 'You are a brain science expert and I will give you the title of a paper related to brain science next, please rewrite or expand it according to the needs I provide with the given title. Please adhere strictly to the example format.'},
              {'role': 'user', 'content': 'Please return two rewrites of the title (One differs more from the original title, one less), one for the abstract expanded by the title (No more than 300 words), one for the experiment you think the subjects underwent according to the title (No more than 100 words), and one for the keywords you think the title covers in brain science.\nTitle: Neural Patterns of Reorganization after Intensive Robot-Assisted Virtual Reality Therapy and Repetitive Task Practice in Patients with Chronic Stroke'},
              {'role': 'assistant', 'content': 'Title1:The Influence of Robot-Assisted Virtual Reality Therapy and Repetitive Task Practice on Neural Reorganization Patterns in Chronic Stroke Patients\nTitle2:Neural Reorganization Patterns in Chronic Stroke Patients following Intensive Robot-Assisted Virtual Reality Therapy and Repetitive Task Practice\nAbstract:Several approaches to rehabilitation of the hand following a stroke have emerged over the last two decades. These treatments, including repetitive task practice (RTP), robotically assisted rehabilitation and virtual rehabilitation activities, produce improvements in hand function but have yet to reinstate function to pre-stroke levels-which likely depends on developing the therapies to impact cortical reorganization in a manner that favors or supports recovery. Understanding cortical reorganization that underlies the above interventions is therefore critical to inform how such therapies can be utilized and improved and is the focus of the current investigation. Specifically, we compare neural reorganization elicited in stroke patients participating in two interventions: a hybrid of robot-assisted virtual reality (RAVR) rehabilitation training and a program of RTP training. Ten chronic stroke subjects participated in eight 3-h sessions of RAVR therapy. Another group of nine stroke subjects participated in eight sessions of matched RTP therapy. Functional magnetic resonance imaging (fMRI) data were acquired during paretic hand movement, before and after training. We compared the difference between groups and sessions (before and after training) in terms of BOLD intensity, laterality index of activation in sensorimotor areas, and the effective connectivity between ipsilesional motor cortex (iMC), contralesional motor cortex, ipsilesional primary somatosensory cortex (iS1), ipsilesional ventral premotor area (iPMv), and ipsilesional supplementary motor area. 
Last, we analyzed the relationship between changes in fMRI data and functional improvement measured by the Jebsen Taylor Hand Function Test (JTHFT), in an attempt to identify how neurophysiological changes are related to motor improvement. Subjects in both groups demonstrated motor recovery after training, but fMRI data revealed RAVR-specific changes in neural reorganization patterns. First, BOLD signal in multiple regions of interest was reduced and re-lateralized to the ipsilesional side. Second, these changes correlated with improvement in JTHFT scores. Our findings suggest that RAVR training may lead to different neurophysiological changes when compared with traditional therapy. This effect may be attributed to the influence that augmented visual and haptic feedback during RAVR training exerts over higher-order somatosensory and visuomotor areas.\nExperiment:The study involved 19 subjects with chronic ischemic stroke who participated in a 2-week training program, either robotic-assisted virtual rehabilitation (RAVR) or a traditional rehabilitation program (RTP). The RAVR group trained using a virtual reality system with robot assistance, while the RTP group received traditional rehabilitation. Both groups were trained for 4 days a week, 3 hours per day. The study aimed to examine the effectiveness of RAVR in promoting recovery of upper extremity function, as demonstrated by improved Modified Ashworth Scale scores and Chedoke-McMaster Impairment Inventory stages.\nKeywords:stroke, virtual reality, rehabilitation, motor control and learning/plasticity, functional magnetic resonance imaging neuroimaging, connectivity analysis'},
              {'role': 'user', 'content': 'Please return two rewrites of the title (One differs more from the original title, one less), one for the abstract expanded by the title (No more than 300 words), one for the experiment you think the subjects underwent according to the title (No more than 100 words), and one for the keywords you think the title covers in brain science.\nTitle: {}'.format(text)}
          ]
      )
      return response['choices'][0]['message']['content']

    # The three handlers below are identical apart from the printed name:
    # pause 0.1 s and retry, or give up with '\n' on the final attempt.
    except openai.error.APIError as e:
      if try_number == max_tries - 1:
        print('APIError')
        return '\n'
      else:
        time.sleep(0.1)
    except openai.error.Timeout as e:
      if try_number == max_tries - 1:
        print('Timeout')
        return '\n'
      else:
        time.sleep(0.1)
    except openai.error.APIConnectionError as e:
      if try_number == max_tries - 1:
        print('APIConnectionError')
        return '\n'
      else:
        time.sleep(0.1)

main

In [None]:
# Prefer a key supplied through the environment; the placeholder is kept as
# the fallback so the cell still behaves as before when it is edited by hand.
openai.api_key = os.environ.get('OPENAI_API_KEY', '[your_key]')

# Inputs are loaded once, outside the retry loop: a missing file should stop
# the cell immediately instead of being retried forever.
neuroquery_meta = pd.read_table('/content/drive/MyDrive/Colab/CoordinateGPT/data-neuroquery_version-1_metadata.tsv.gz') # articles data
out_dir = '/content/drive/MyDrive/Colab/CoordinateGPT/ChatAUG/AUG_1'
os.makedirs(out_dir, exist_ok=True)

while True: # augment brain datasets
  try:
    for index in tqdm(neuroquery_meta.index):
      title = neuroquery_meta.at[index, 'title']
      article_id = int(neuroquery_meta.at[index, 'id'])

      save_path = os.path.join(out_dir, str(article_id) + '.npy')
      if os.path.exists(save_path):
        # Already augmented in an earlier pass; this makes restarts cheap.
        continue

      # Pass the title itself: formatting a 1-tuple into the prompt would
      # insert its repr, e.g. "('Some title',)".
      text_aug = text2aug(title) # augmentation of five types of data

      # Cut the reply at each section marker in turn.  A reply that lacks a
      # marker raises IndexError, which the handler below logs and retries.
      A = text_aug.split('Title1:')
      B = A[1].split('Title2:')
      C = B[1].split('Abstract:')
      D = C[1].split('Experiment:')
      E = D[1].split('Keywords:')

      # strip() removes the separators around each section without cutting
      # a real character when the reply does not end the section with '\n'.
      Title1 = B[0].strip()
      Title2 = C[0].strip()
      Abstract = D[0].strip()
      Experiment = E[0].strip()
      Keyword = E[1].strip()

      # Stored as a 0-d object array wrapping the dict (np.save pickles it).
      aug = np.array({'Title1': Title1, 'Title2': Title2, 'Abstract': Abstract, 'Experiment': Experiment, 'Keyword': Keyword})
      np.save(save_path, aug)

      time.sleep(1.25)


    break # end loops

  except Exception as e: # avoid aborting chatgpt when it encounters an error
    # Report what went wrong so a persistent failure is visible rather than
    # an apparently hung cell, then start another pass.
    print('Augmentation pass interrupted ({}: {}); retrying in 5 s'.format(type(e).__name__, e))
    time.sleep(5)
    continue

# Chat2Brain_Train

In [None]:
!pip install -q transformers

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [13]:
import os
import sys
import transformers
import torch


import torch.nn as nn
import numpy as np
import pandas as pd
import nibabel as nib


from tqdm import tqdm
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset
from argparse import ArgumentParser

model_decoder

In [9]:
class SimpleConvResBlock3D(nn.Module):
    """3-D convolution followed by batch normalisation and an activation.

    The convolution uses kernel 3, stride 1 and padding 2, so every spatial
    dimension of the output is 2 larger than the input.  Despite the name,
    the block has no skip connection.
    """

    def __init__(self, in_channels, out_channels, act_fn):
        super().__init__()
        # Attribute names are also the state-dict keys of saved checkpoints.
        self.conv1 = nn.Conv3d(
            in_channels, out_channels, kernel_size=3, stride=1, padding=2
        )
        self.bn1 = nn.BatchNorm3d(out_channels)
        self.act_fn = act_fn

    def forward(self, input_):
        """Apply conv -> batch norm -> activation to ``input_``."""
        return self.act_fn(self.bn1(self.conv1(input_)))


class ConvResBlock3D(nn.Module):
    """Two shape-preserving 3-D convolutions joined by a skip connection.

    The skip path carries the output of the first convolution (not the raw
    input), which is added to the output of the second convolution before
    the final batch norm and activation.
    """

    def __init__(self, in_channels, out_channels, act_fn):
        super().__init__()
        # Creation order and attribute names are kept stable because they
        # determine the state-dict layout.
        self.conv1 = nn.Conv3d(
            in_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        self.bn1 = nn.BatchNorm3d(out_channels)
        self.act_fn = act_fn
        self.conv2 = nn.Conv3d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        self.bn2 = nn.BatchNorm3d(out_channels)

    def forward(self, input_):
        """Return ``act(bn2(conv1(x) + conv2(act(bn1(conv1(x))))))``."""
        shortcut = self.conv1(input_)
        branch = self.conv2(self.act_fn(self.bn1(shortcut)))
        return self.act_fn(self.bn2(shortcut + branch))


class TransConvResBlock3D(nn.Module):
    """Upsampling block: transposed 3-D conv -> batch norm -> activation.

    With kernel 4, stride 2, padding 1 and output padding 1, a spatial
    dimension of size ``d`` becomes ``2 * d + 1``.  Despite the name, the
    block has no skip connection.
    """

    def __init__(self, in_channels, out_channels, act_fn):
        super().__init__()
        # Attribute names are also the state-dict keys of saved checkpoints.
        self.trans_conv1 = nn.ConvTranspose3d(
            in_channels,
            out_channels,
            kernel_size=4,
            stride=2,
            padding=1,
            output_padding=1,
        )
        self.bn1 = nn.BatchNorm3d(out_channels)
        self.act_fn = act_fn

    def forward(self, input_):
        """Apply transposed conv -> batch norm -> activation to ``input_``."""
        return self.act_fn(self.bn1(self.trans_conv1(input_)))


class ImageDecoder(nn.Module):
    """Decode a low-resolution feature volume into a brain-map volume.

    Three ``TransConvResBlock3D`` stages upsample the input while reducing
    the channel count from ``num_filter`` to ``num_filter // 4``; a final
    ``SimpleConvResBlock3D`` maps to ``out_channels``.  The first index of
    every spatial axis is cropped from the result, so a (4, 5, 4) input grid
    produces a (40, 48, 40) output grid.

    NOTE: the ``act_fn`` argument is accepted but not used; every stage runs
    with ``nn.Hardtanh(min_val=-6, max_val=6)``.
    """

    def __init__(self, in_channels, out_channels, act_fn=nn.Sigmoid, num_filter=256):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_filter = num_filter
        # Overrides whatever was passed in (see the class docstring).
        act_fn = nn.Hardtanh(min_val=-6, max_val=6)

        half_width = self.num_filter // 2
        quarter_width = self.num_filter // 4

        self.trans_1 = TransConvResBlock3D(self.in_channels, self.num_filter, act_fn)
        self.trans_2 = TransConvResBlock3D(self.num_filter, half_width, act_fn)
        self.trans_3 = TransConvResBlock3D(half_width, quarter_width, act_fn)

        self.out = SimpleConvResBlock3D(quarter_width, self.out_channels, act_fn)

    def forward(self, input_):
        """Run the four stages in order and crop index 0 of each spatial axis."""
        hidden = input_
        for stage in (self.trans_1, self.trans_2, self.trans_3, self.out):
            hidden = stage(hidden)
        return hidden[:, :, 1:, 1:, 1:]

model_main_with_encoder

In [10]:
class Text2BrainModel(nn.Module):
    """Text-to-brain-map model: BERT encoder -> linear layer -> 3-D decoder.

    Args:
        out_channels: Number of channels of the decoded volume.
        fc_channels: Channels of the feature volume fed to the decoder.
        decoder_filters: ``num_filter`` of the ``ImageDecoder``.
        pretrained_bert_dir: Directory (or name) given to both
            ``BertTokenizer.from_pretrained`` and ``BertModel.from_pretrained``.
        decoder_act_fn: Forwarded to ``ImageDecoder`` as ``act_fn``.
        drop_p: Dropout probability used before and after the linear layer.
        decoder_input_shape: Spatial shape of the decoder input volume.
    """

    def __init__(self, out_channels, fc_channels, decoder_filters, pretrained_bert_dir, decoder_act_fn=nn.Sigmoid, drop_p=0.5, decoder_input_shape=(4, 5, 4)):
        super().__init__()
        self.out_channels = out_channels
        self.fc_channels = fc_channels
        self.decoder_filters = decoder_filters
        # Stored as a private list so instances never share a default object.
        self.decoder_input_shape = list(decoder_input_shape)
        self.drop_p = drop_p

        self.tokenizer = transformers.BertTokenizer.from_pretrained(pretrained_bert_dir)
        self.encoder = transformers.BertModel.from_pretrained(pretrained_bert_dir)
        if torch.cuda.is_available():
          self.encoder = self.encoder.cuda()

        self.fc = nn.Linear(
          # Width of the pooled encoder output (768 for BERT-base models).
          in_features=self.encoder.config.hidden_size,
          out_features=self.decoder_input_shape[0]*self.decoder_input_shape[1]*self.decoder_input_shape[2]*self.fc_channels)
        self.dropout = nn.Dropout(self.drop_p)
        self.relu = nn.ReLU()

        self.decoder = ImageDecoder(in_channels=self.fc_channels, out_channels=self.out_channels, num_filter=self.decoder_filters, act_fn=decoder_act_fn)


    def forward(self, texts):
        """Decode a batch of strings into volumes.

        Args:
            texts: Iterable of strings; each is tokenised separately and the
                batch is zero-padded to the longest sequence.

        Returns:
            The decoder output for the batch.
        """
        batch = [self._tokenize(x) for x in texts]

        # Send the inputs to wherever the encoder actually lives instead of
        # assuming "CUDA is available" means "the model is on CUDA".
        device = next(self.encoder.parameters()).device
        in_mask = self._pad_mask(batch, batch_first=True).to(device)
        in_ = pad_sequence(batch, batch_first=True).to(device)

        # Index rather than tuple-unpack: element 1 is the pooled output for
        # both a plain tuple and a ModelOutput, whereas iterating over a
        # ModelOutput yields its key names.
        embedding = self.encoder(in_, attention_mask=in_mask)[1]

        x = self.dropout(embedding)
        x = self.fc(x)
        x = self.dropout(x)
        x = self.relu(x)

        decoder_tensor_shape = [-1, self.fc_channels] + list(self.decoder_input_shape)
        x = x.view(decoder_tensor_shape)

        out = self.decoder(x)

        return out


    def _tokenize(self, text):
        """Encode one string to a 1-D tensor of token ids (at most 512 long)."""
        return self.tokenizer.encode(text, add_special_tokens=True, return_tensors='pt', truncation=True, max_length=512).squeeze(0)


    def _pad_mask(self, sequences, batch_first=False):
        """Build the attention mask: ones over real tokens, zeros over padding."""
        ret = [torch.ones(len(s)) for s in sequences]
        return pad_sequence(ret, batch_first=batch_first)

dataloader

In [12]:
class Chat2BrainDataset(Dataset):
    """Pairs each article's text with its brain map.

    Args:
        metadata: DataFrame with at least the columns ``id`` and ``title``.
        text_dir: Directory holding ``<id>.npy`` augmentation files
            (only read when ``source == 'aug'``).
        brain_dir: Directory holding ``<id>.nii.gz`` brain maps.
        source: ``'title'`` or ``'aug'``.
    """

    def __init__(self, metadata, text_dir, brain_dir, source):
        self.metadata = metadata
        self.text_dir = text_dir
        self.brain_dir = brain_dir
        self.source = source


    def _load_text(self, row):
        """Return the lower-cased text for one metadata row.

        ``'title'`` gives a string.  ``'aug'`` gives ``[title, aug]`` where
        ``aug`` is the dict stored in ``<id>.npy`` with every value
        lower-cased.

        Raises:
            ValueError: if ``self.source`` is neither ``'title'`` nor ``'aug'``.
        """
        if self.source == 'title':
            return row['title'].lower()
        if self.source == 'aug':
            # The file holds a 0-d object array wrapping a dict; .item()
            # unwraps it.  NOTE(security): allow_pickle=True runs the pickle
            # machinery, so only load files this project wrote itself.
            aug = np.load(os.path.join(self.text_dir, str(row['id']) + '.npy'), allow_pickle=True).item()
            return [row['title'].lower(), {key: str(value).lower() for key, value in aug.items()}]
        raise ValueError("Data source not implemented")


    def __getitem__(self, index):
        """Return ``(text, brain_map)`` for the row at position ``index``.

        ``brain_map`` is a float32 CPU tensor with a leading channel axis,
        scaled by its maximum; the training loop moves it to the GPU.
        """
        row = self.metadata.iloc[index]
        id = row['id']

        text = self._load_text(row)

        brain_map = nib.load(os.path.join(self.brain_dir, str(id) + '.nii.gz'))
        brain_map = brain_map.get_fdata()
        # Drop the border voxels of the stored volume.
        brain_map = brain_map[3:-3, 3:-4, :-6]
        brain_map = np.expand_dims(brain_map, 0)
        peak = np.max(brain_map)
        if peak != 0:
            # An all-zero map stays all-zero instead of going through 0/0.
            brain_map = brain_map / peak
        brain_map = np.nan_to_num(brain_map)

        # Built on the CPU so the dataset also works without CUDA and inside
        # DataLoader worker processes.
        return text, torch.as_tensor(brain_map, dtype=torch.float32)


    def __len__(self):
        """Number of metadata rows."""
        return len(self.metadata.index)

utils

In [15]:
def compute_corr_coeff(A, B):
    """Pearson correlation between every row of ``A`` and every row of ``B``.

    Both inputs are 2-D arrays with the same number of columns.  Entry
    ``[i, j]`` of the result is the correlation of ``A[i]`` with ``B[j]``.
    A row with zero variance leads to a division by zero for its entries.
    """
    # Centre every row on its own mean.
    a_centered = A - A.mean(axis=1, keepdims=True)
    b_centered = B - B.mean(axis=1, keepdims=True)

    # Per-row sums of squared deviations.
    a_ss = np.sum(a_centered ** 2, axis=1)
    b_ss = np.sum(b_centered ** 2, axis=1)

    cross_products = a_centered @ b_centered.T
    return cross_products / np.sqrt(np.outer(a_ss, b_ss))


def save_checkpoint(model, optimizer, scheduler, epoch, fname, output_dir):
    """Write a training checkpoint to ``output_dir/fname`` with ``torch.save``.

    The saved dict has the keys ``epoch``, ``state_dict`` (the model),
    ``scheduler`` and ``optimizer``; the last three hold the respective
    ``state_dict()`` results.
    """
    target_path = os.path.join(output_dir, fname)
    torch.save(
        dict(
            epoch=epoch,
            state_dict=model.state_dict(),
            scheduler=scheduler.state_dict(),
            optimizer=optimizer.state_dict(),
        ),
        target_path,
    )

args

In [14]:
def init_args(argv=None):
    """Build the training configuration from command-line style arguments.

    Args:
        argv: Optional list of argument strings.  ``None`` (the default)
            reads ``sys.argv[1:]`` as before.

    Returns:
        argparse.Namespace with one attribute per option declared below.

    Arguments the parser does not know are ignored rather than treated as an
    error, so the function also works inside a Jupyter kernel, whose own
    ``-f <connection file>`` argument is present in ``sys.argv``.
    """
    parser = ArgumentParser()

    parser.add_argument("--gpus", type=str,
                        default="0, 1, 2, 3",
                        help="Which gpus to use?")

    parser.add_argument("--ver", type=str,
                        default="neuroquery",
                        help="Additional string for the name of the file")

    parser.add_argument("--train_csv",
                        type=str,
                        help="Path to the csv containing the training articles data")

    parser.add_argument("--val_csv",
                        type=str,
                        help="Path to the csv containing the validation articles data")

    parser.add_argument("--images_dir",
                        type=str,
                        help="Directory containing activation maps, should be of size (40, 48, 40)")

    parser.add_argument("--pretrained_bert_dir",
                        type=str,
                        default="/disk1/wyn/workshop/ChatGPT/text2brain-main/scibert_scivocab_uncased",
                        help="Directory containing pretrained BERT model")

    parser.add_argument("--pretrained_tokenizer_dir",
                        type=str,
                        help="Directory containing pretrained tokenizer")

    parser.add_argument("--mask_file",
                        type=str,
                        help="Brain mask file")

    parser.add_argument("--save_dir", type=str,
                        default="/disk1/wyn/workshop/ChatGPT/text2brain-main/Chat2Brain_checkpoint/",
                        help="Path to the output directory")

    parser.add_argument("--save_test_dir", type=str,
                        default="/disk1/wyn/workshop/ChatGPT/text2brain-main/Chat2Brain_test/",
                        help="Path to the output directory")

    parser.add_argument("--mask_dir", type=str,
                        default="/disk1/wyn/workshop/ChatGPT/text2brain-main/data/brain_maps/neuroquery/",
                        help="Path to the mask directory")

    parser.add_argument("--text_dir", type=str,
                        default="/disk1/wyn/workshop/ChatGPT/text2brain-main/data/ChatAUG/",
                        help="Path to the text directory")

    parser.add_argument("--metadata_dir", type=str,
                        default="/disk1/wyn/workshop/ChatGPT/text2brain-main/data-neuroquery_version-1_metadata.tsv.gz",
                        help="Path to the metadata directory")

    parser.add_argument("--n_fc_channels",
                        type=int,
                        default=1024,
                        help="Base number of channels in the FC layer")

    parser.add_argument("--n_decoder_channels",
                        type=int,
                        default=256,
                        help="Base number of channels in the image decoder")

    parser.add_argument("--n_output_channels",
                        type=int,
                        default=1,
                        help="Number of output channels")

    parser.add_argument("--lr",
                        type=float,
                        default=3e-2,
                        help="Learning rate")

    parser.add_argument("--weight_decay",
                        type=float,
                        default=1e-6,
                        help="Weight decay of the optimizer")

    parser.add_argument("--drop_p",
                        type=float,
                        default=0.6,
                        help="Dropout proportion for FC layer")

    parser.add_argument("--epochs",
                        type=int,
                        default=550,
                        help="Training epochs")

    parser.add_argument("--seed",
                        type=int,
                        default=28)

    parser.add_argument("--random_seed",
                        type=int,
                        default=60)

    # Three integers, e.g. ``--split 6 2 2``.  ``type=list`` would have split
    # the command-line string into single characters.
    parser.add_argument("--split",
                        type=int,
                        nargs=3,
                        default=[6, 2, 2],
                        help="Relative sizes of the train, validation and test sets")

    parser.add_argument("--checkpoint_file",
                        type=str,
                        default="/disk1/wyn/workshop/ChatGPT/text2brain-main/Chat2Brain_checkpoint/neuroquery_title_fc1024_dec256_lr0.03_decay1e-06_drop0.6_seed28/checkpoint_1450.pth",
                        help="Path to the checkpoint file to be loaded into the model")

    parser.add_argument("--checkpoint_interval",
                        type=int,
                        default=10,
                        help="Number of epochs between saved checkpoints")

    parser.add_argument("--batch_size",
                        type=int,
                        default=24,
                        help="Batch size")

    parser.add_argument("--Scaling_factor",
                        type=int,
                        default=1,
                        help="Scaling factor")

    parser.add_argument("--phrase",
                        type=str,
                        default=None,
                        help="Input phrase for prediction")

    parser.add_argument("--source",
                        type=str,
                        default="title",
                        help="Source type")

    # parse_known_args tolerates arguments that belong to the host process.
    known_args, _unknown = parser.parse_known_args(argv)
    return known_args

train

In [16]:
def train(model, train_loader, optimizer, loss_fn, mask):
    """Run one training epoch and return ``(avg_loss, avg_corr)``.

    Reads the module-level ``args`` for ``Scaling_factor`` and ``source``.

    Args:
        model: Model called with a batch of strings.
        train_loader: Yields ``(text, brain)`` batches.
        optimizer: Stepped once per forward/backward pass.
        loss_fn: Called as ``loss_fn(output, target)``.
        mask: Not used by this function; kept so existing calls still work.

    In ``'aug'`` mode every batch is trained on seven text variants in turn
    (the title first and last), i.e. seven optimizer steps per batch.  The
    correlation is computed from the output of the last pass only.
    """
    model.train()
    avg_loss = 0
    avg_corr = 0
    n_batches = len(train_loader)
    device = next(model.parameters()).device

    for batch_idx, (text, brain) in enumerate(train_loader):
        brain_map = (args.Scaling_factor * brain).to(device)

        if args.source == 'title':
            text_variants = [text]
        elif args.source == 'aug':
            title, aug = text[0], text[1]
            # The augmentation files store this entry under 'Keyword';
            # 'Keywords' is accepted as well.
            keywords = aug['Keywords'] if 'Keywords' in aug else aug['Keyword']
            text_variants = [title, aug['Title1'], aug['Title2'], aug['Abstract'], aug['Experiment'], keywords, title]
        else:
            raise ValueError("Data source not implemented")

        for text_variant in text_variants:
            optimizer.zero_grad()

            output = model(text_variant)

            loss = loss_fn(output, brain_map)
            loss.backward()
            optimizer.step()
            # Average over every pass of the epoch.
            avg_loss += loss.item() / (n_batches * len(text_variants))

        output_np = output.cpu().detach().numpy()
        target_np = brain_map.cpu().detach().numpy()

        all_corr = compute_corr_coeff(
            output_np.reshape(output_np.shape[0], -1),
            target_np.reshape(output_np.shape[0], -1))
        corr = np.mean(np.diag(all_corr))
        if np.isnan(corr):
            # Dump what is needed to diagnose the NaN, then stop the run.
            print(text,
                  np.isnan(output_np).any(),
                  np.isnan(target_np).any(), all_corr, corr)
            print("Output", torch.max(output), output)
            print("Target", torch.max(brain_map), brain_map)
            sys.exit(1)

        avg_corr = avg_corr + corr / n_batches

        if batch_idx % 100 == 99:
            # The batch size comes from the target tensor: in 'aug' mode
            # ``text`` is a [titles, augmentations] pair of length 2.
            print('[{}/{} ({:.0f}%)] Loss: {:.6f} Corr: {:.6f}'.format(
                batch_idx * brain_map.shape[0], len(train_loader.dataset),
                100. * batch_idx / n_batches, loss.item(), corr))

    print('  Train: avg loss: {:.6f} - avg corr: {:.6f}'.format(avg_loss, avg_corr))

    return avg_loss, avg_corr

eval

In [None]:
def eval(model, val_loader, loss_fn, mask):
    """Evaluate on the validation set and return ``(avg_loss, avg_corr)``.

    Reads the module-level ``args`` for ``Scaling_factor`` and ``source``.
    Note that this function shadows the builtin ``eval``.

    Args:
        model: Model called with a batch of strings.
        val_loader: Yields ``(text, brain)`` batches.
        loss_fn: Called as ``loss_fn(output, target)``.
        mask: Index applied to the third axis of output and target before
            the correlation is computed.

    In ``'aug'`` mode the loss is averaged over seven text variants per
    batch; the correlation uses the output of the last variant only.
    """
    model.eval()
    avg_loss = 0
    avg_corr = 0
    n_batches = len(val_loader)
    device = next(model.parameters()).device

    with torch.no_grad():
        for batch_idx, (text, brain) in enumerate(val_loader):
            brain_map = (args.Scaling_factor * brain).to(device)

            if args.source == 'title':
                text_variants = [text]
            elif args.source == 'aug':
                title, aug = text[0], text[1]
                # The augmentation files store this entry under 'Keyword';
                # 'Keywords' is accepted as well.
                keywords = aug['Keywords'] if 'Keywords' in aug else aug['Keyword']
                text_variants = [title, aug['Title1'], aug['Title2'], aug['Abstract'], aug['Experiment'], keywords, title]
            else:
                raise ValueError("Data source not implemented")

            for text_variant in text_variants:
                output = model(text_variant)

                loss = loss_fn(output, brain_map)
                # Average over every pass of the evaluation.
                avg_loss += loss.item() / (n_batches * len(text_variants))

            output_np = output.cpu().detach().numpy()[:, :, mask]
            target_np = brain_map.cpu().detach().numpy()[:, :, mask]
            corr = np.mean(
                np.diag(
                    compute_corr_coeff(
                        output_np.reshape(output_np.shape[0], -1),
                        target_np.reshape(output_np.shape[0], -1))))
            avg_corr = avg_corr + corr / n_batches

        print('  Val: avg loss: {:.6f} - avg corr: {:.6f}'.format(
            avg_loss, avg_corr))

        return avg_loss, avg_corr

main

In [None]:
if __name__=='__main__':
    # init
    args = init_args()

    os.environ["CUDA_DEVICE_ORDER"] = 'PCI_BUS_ID'
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus

    output_name = f'continue_{args.ver}_{args.source}_fc{args.n_fc_channels}_dec{args.n_decoder_channels}_lr{args.lr}_decay{args.weight_decay}_drop{args.drop_p}_seed{args.seed}'
    output_dir = os.path.join(args.save_dir, output_name)
    if os.path.exists(output_dir):
        print(f'Output dir exists: {output_dir}')
    os.makedirs(output_dir, exist_ok=True)

    mask_dir = args.mask_dir

    writer = SummaryWriter(os.path.join(output_dir, "logs")) # tensorboard

    # load Data
    np.random.seed(args.seed)
    # Also seed torch so dropout and shuffling are repeatable.
    torch.manual_seed(args.seed)
    meta_data = pd.read_table(args.metadata_dir)

    # randomly divide the dataset into training, validation and test sets
    split_total = args.split[0] + args.split[1] + args.split[2]
    train_meta = meta_data.sample(frac=args.split[0] / split_total, random_state=args.random_seed)

    # Everything not sampled for training.  Selecting by index label replaces
    # DataFrame.append + drop_duplicates, which no longer exists in pandas 2.
    between_meta = meta_data.drop(train_meta.index)

    val_meta = between_meta.sample(frac=args.split[1] / (args.split[1] + args.split[2]), random_state=args.random_seed)
    test_meta = between_meta.drop(val_meta.index)

    train_meta.to_csv(os.path.join(args.save_dir, 'train_meta.csv'))
    val_meta.to_csv(os.path.join(args.save_dir, 'val_meta.csv'))
    test_meta.to_csv(os.path.join(args.save_dir, 'test_meta.csv'))

    train_dataset = Chat2BrainDataset(train_meta, args.text_dir, args.mask_dir, args.source)
    val_dataset = Chat2BrainDataset(val_meta, args.text_dir, args.mask_dir, args.source)

    train_loader = DataLoader(dataset=train_dataset, batch_size=args.batch_size, shuffle=True)
    val_loader = DataLoader(dataset=val_dataset, batch_size=args.batch_size, shuffle=False)
    print("Number of training articles:", len(train_dataset))
    print("Number of validation articles:", len(val_dataset))

    # init model
    model = Text2BrainModel(
        out_channels=1,
        fc_channels=args.n_fc_channels,
        decoder_filters=args.n_decoder_channels,
        pretrained_bert_dir=args.pretrained_bert_dir,
        drop_p=args.drop_p)
    model.cuda()

    # loading checkpoint (model weights only; optimizer and scheduler state
    # stored in the file are not restored)
    if args.checkpoint_file is not None:
        state_dict = torch.load(args.checkpoint_file)['state_dict']
        model.load_state_dict(state_dict)

    # optimizer
    num_training_steps = len(train_dataset) * args.epochs
    num_warmup_steps = num_training_steps // 3
    # transformers.AdamW is absent from recent transformers releases; fall
    # back to torch's implementation there.  eps is passed explicitly so both
    # use the value transformers.AdamW used by default.
    optimizer_cls = getattr(transformers, "AdamW", None) or torch.optim.AdamW
    opt = optimizer_cls([
        {'params': model.fc.parameters()},
        {'params': model.decoder.parameters()},
        {'params': model.encoder.parameters(), 'lr': 1e-5},
    ], lr=args.lr, weight_decay=args.weight_decay, eps=1e-6)
    # NOTE(review): the schedule length is counted in samples * epochs, but
    # sched.step() below runs once per epoch, so only the first
    # ``args.epochs`` steps of the warm-up are ever reached.  Left unchanged
    # because it alters the effective learning rate; confirm the intent.
    sched = transformers.get_linear_schedule_with_warmup(opt, num_warmup_steps, num_training_steps)

    loss_fn = nn.MSELoss(reduction="sum")

    val_losses = []
    val_corrs = []
    best_loss = sys.float_info.max
    best_corr = 0.0

    try:
        for epoch in tqdm(range(args.epochs)):
            # NOTE(review): [2, 4] is passed as ``mask``; eval() uses it to
            # index the third axis before computing the correlation.
            # Confirm that restricting to these two indices is intended.
            train_loss, train_corr = train(model, train_loader, opt, loss_fn, [2, 4])
            val_loss, val_corr = eval(model, val_loader, loss_fn, [2, 4])

            writer.add_scalar('training loss', train_loss, epoch)
            writer.add_scalar('training corr', train_corr, epoch)

            writer.add_scalar('validation loss', val_loss, epoch)
            writer.add_scalar('validation corr', val_corr, epoch)

            val_losses.append(val_loss)
            val_corrs.append(val_corr)

            # Model selection uses the mean over the latest interval.
            mean_loss = np.mean(val_losses[-args.checkpoint_interval:])
            mean_corr = np.mean(val_corrs[-args.checkpoint_interval:])

            if (epoch > args.epochs * 0.1) and (epoch % args.checkpoint_interval == 0):
                if mean_loss < best_loss:
                    save_checkpoint(model, opt, sched, epoch, "best_loss.pth", output_dir)
                    best_loss = mean_loss
                if mean_corr > best_corr:
                    save_checkpoint(model, opt, sched, epoch, "best_corr.pth", output_dir)
                    best_corr = mean_corr
                save_checkpoint(model, opt, sched, epoch, f'checkpoint_{epoch}.pth', output_dir)
            sched.step()
        save_checkpoint(model, opt, sched, args.epochs, f'checkpoint_{args.epochs}.pth', output_dir)
    finally:
        # Flush the tensorboard log even if training stops early.
        writer.close()

# Chat2Brain_Train_DDP

In [None]:
!pip install -q transformers

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [18]:
import torch.distributed as dist


import os
import sys
import transformers
import argparse
import torch


import torch.nn as nn
import numpy as np
import pandas as pd
import nibabel as nib


from tqdm import tqdm
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset
from argparse import ArgumentParser

model_decoder

In [None]:
class SimpleConvResBlock3D(nn.Module):
    """One 3D convolution followed by batch normalisation and an activation.

    The convolution uses a 3x3x3 kernel with stride 1 and padding 2, so each
    spatial dimension of the output is 2 voxels larger than the input.

    Args:
        in_channels: number of channels in the input volume.
        out_channels: number of channels produced by the convolution.
        act_fn: callable applied to the normalised feature map.
    """

    def __init__(self, in_channels, out_channels, act_fn):
        super().__init__()
        self.conv1 = nn.Conv3d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            padding=2,
        )
        self.bn1 = nn.BatchNorm3d(out_channels)
        self.act_fn = act_fn

    def forward(self, input_):
        """Return ``act_fn(bn1(conv1(input_)))``."""
        return self.act_fn(self.bn1(self.conv1(input_)))


class ConvResBlock3D(nn.Module):
    """Two-convolution residual block that keeps the spatial size unchanged.

    The output of the first convolution is used as the shortcut; the residual
    branch is ``conv2(act_fn(bn1(shortcut)))``. Their sum is batch-normalised
    and passed through the activation.

    Args:
        in_channels: number of channels in the input volume.
        out_channels: number of channels produced by both convolutions.
        act_fn: callable used after each batch normalisation.
    """

    def __init__(self, in_channels, out_channels, act_fn):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1)
        self.conv1 = nn.Conv3d(in_channels, out_channels, **conv_kwargs)
        self.bn1 = nn.BatchNorm3d(out_channels)
        self.act_fn = act_fn
        self.conv2 = nn.Conv3d(out_channels, out_channels, **conv_kwargs)
        self.bn2 = nn.BatchNorm3d(out_channels)

    def forward(self, input_):
        """Apply the block to ``input_`` and return the activated sum."""
        shortcut = self.conv1(input_)
        branch = self.conv2(self.act_fn(self.bn1(shortcut)))
        return self.act_fn(self.bn2(shortcut + branch))


class TransConvResBlock3D(nn.Module):
    """Up-sampling block: transposed 3D convolution, batch norm, activation.

    With kernel 4, stride 2, padding 1 and output_padding 1 a spatial extent
    of ``n`` becomes ``2 * n + 1``.

    Args:
        in_channels: number of channels in the input volume.
        out_channels: number of channels produced by the transposed convolution.
        act_fn: callable applied to the normalised feature map.
    """

    def __init__(self, in_channels, out_channels, act_fn):
        super().__init__()
        self.trans_conv1 = nn.ConvTranspose3d(
            in_channels,
            out_channels,
            kernel_size=4,
            stride=2,
            padding=1,
            output_padding=1,
        )
        self.bn1 = nn.BatchNorm3d(out_channels)
        self.act_fn = act_fn

    def forward(self, input_):
        """Return ``act_fn(bn1(trans_conv1(input_)))``."""
        return self.act_fn(self.bn1(self.trans_conv1(input_)))


class ImageDecoder(nn.Module):
    """Decode a low-resolution feature volume into a brain-map volume.

    Three ``TransConvResBlock3D`` stages reduce the channel count from
    ``num_filter`` to ``num_filter // 4``; a ``SimpleConvResBlock3D`` then maps
    to ``out_channels``. The first slice along each of the three spatial axes
    of the result is dropped before returning.

    Args:
        in_channels: channels of the input feature volume.
        out_channels: channels of the decoded volume.
        act_fn: accepted for interface compatibility only. It is overridden
            below by a fixed ``Hardtanh(-6, 6)`` shared by every stage.
        num_filter: channel width of the first up-sampling stage.
    """

    def __init__(self, in_channels, out_channels, act_fn=nn.Sigmoid, num_filter=256):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_filter = num_filter
        # The caller-supplied activation is deliberately not used: every stage
        # shares this single Hardtanh instance.
        act_fn = nn.Hardtanh(min_val=-6, max_val=6)

        half_width = self.num_filter // 2
        quarter_width = self.num_filter // 4

        self.trans_1 = TransConvResBlock3D(self.in_channels, self.num_filter, act_fn)
        self.trans_2 = TransConvResBlock3D(self.num_filter, half_width, act_fn)
        self.trans_3 = TransConvResBlock3D(half_width, quarter_width, act_fn)

        self.out = SimpleConvResBlock3D(quarter_width, self.out_channels, act_fn)

    def forward(self, input_):
        """Run the three up-sampling stages and the output stage, then crop."""
        features = input_
        for upsample in (self.trans_1, self.trans_2, self.trans_3):
            features = upsample(features)

        decoded = self.out(features)

        # Drop the leading slice on each spatial axis.
        return decoded[:, :, 1:, 1:, 1:]

model_main_with_encoder

In [None]:
class Text2BrainModel(nn.Module):
    """Text-to-brain-map model: BERT encoder -> linear projection -> 3D decoder.

    Each input text is tokenised, the batch is padded and encoded, and the
    encoder's second output (the pooled sentence embedding) is projected to a
    ``fc_channels x decoder_input_shape`` volume that ``ImageDecoder`` expands.

    Args:
        out_channels: stored on the instance. NOTE(review): the decoder is
            always built with ``out_channels=1``; confirm whether this
            argument should be forwarded.
        fc_channels: channel count of the volume fed to the decoder.
        decoder_filters: ``num_filter`` passed to ``ImageDecoder``.
        pretrained_bert_dir: path or name given to ``from_pretrained`` for
            both tokenizer and encoder. The projection assumes a
            768-dimensional pooled output.
        decoder_act_fn: forwarded to ``ImageDecoder`` as ``act_fn``.
        drop_p: dropout probability used before and after the projection.
        decoder_input_shape: three spatial sizes of the decoder input volume.
    """

    def __init__(self, out_channels, fc_channels, decoder_filters, pretrained_bert_dir, decoder_act_fn=nn.Sigmoid, drop_p=0.5, decoder_input_shape=[4, 5, 4]):
        super().__init__()
        self.out_channels = out_channels
        self.fc_channels = fc_channels
        self.decoder_filters = decoder_filters
        # Copy so the shared default list is never aliased and tuples are accepted.
        self.decoder_input_shape = list(decoder_input_shape)
        self.drop_p = drop_p

        self.tokenizer = transformers.BertTokenizer.from_pretrained(pretrained_bert_dir)
        self.encoder = transformers.BertModel.from_pretrained(pretrained_bert_dir)
        if torch.cuda.is_available():
            self.encoder = self.encoder.cuda()

        depth, height, width = self.decoder_input_shape
        self.fc = nn.Linear(
            in_features=768,
            out_features=depth * height * width * self.fc_channels)
        self.dropout = nn.Dropout(self.drop_p)
        self.relu = nn.ReLU()

        self.decoder = ImageDecoder(in_channels=self.fc_channels, out_channels=1, num_filter=self.decoder_filters, act_fn=decoder_act_fn)

    def forward(self, texts):
        """Decode a batch of strings into brain-map volumes.

        Args:
            texts: iterable of strings, one per sample.

        Returns:
            The decoder output for the batch.
        """
        batch = [self._tokenize(x) for x in texts]

        in_mask = self._pad_mask(batch, batch_first=True)
        in_ = pad_sequence(batch, batch_first=True)

        # Follow the encoder's device rather than the process-wide default GPU,
        # so the inputs always land where the weights are.
        device = self._encoder_device()
        in_ = in_.to(device)
        in_mask = in_mask.to(device)

        # Index instead of tuple-unpacking: when the encoder returns a
        # dict-like ModelOutput, unpacking would iterate over its keys
        # (strings). Position 1 is the pooled output for both return styles.
        encoder_out = self.encoder(in_, attention_mask=in_mask)
        embedding = encoder_out[1]

        x = self.dropout(embedding)
        x = self.fc(x)
        x = self.dropout(x)
        x = self.relu(x)

        decoder_tensor_shape = [-1, self.fc_channels] + list(self.decoder_input_shape)
        x = x.view(decoder_tensor_shape)

        out = self.decoder(x)

        return out

    def _encoder_device(self):
        """Return the device holding the encoder weights."""
        first_param = next(self.encoder.parameters(), None)
        if first_param is not None:
            return first_param.device
        # Parameter-free encoder: keep the historical "GPU if present" choice.
        return torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def _tokenize(self, text):
        """Encode one string to a 1-D tensor of token ids (at most 512 tokens)."""
        return self.tokenizer.encode(text, add_special_tokens=True, return_tensors='pt', truncation=True, max_length=512).squeeze(0)

    def _pad_mask(self, sequences, batch_first=False):
        """Build a padded attention mask: 1 for real tokens, 0 for padding."""
        ret = [torch.ones(len(s)) for s in sequences]
        return pad_sequence(ret, batch_first=batch_first)

dataloader

In [None]:
class Chat2BrainDataset(Dataset):
    """Pairs each metadata row with its text and its cropped brain map.

    Args:
        metadata: DataFrame read positionally; the code uses its ``id`` and
            ``title`` columns.
        text_dir: directory holding ``<id>.npy`` files (used when
            ``source == 'aug'``).
        brain_dir: directory holding ``<id>.nii.gz`` brain maps.
        source: ``'title'`` to return the lower-cased title string, ``'aug'``
            to return a dict of lower-cased text variants.
    """

    _AUG_FIELDS = ('Title1', 'Title2', 'Abstract', 'Experiment', 'Keyword')

    def __init__(self, metadata, text_dir, brain_dir, source):
        self.metadata, self.text_dir = metadata, text_dir
        self.brain_dir, self.source = brain_dir, source

    def __getitem__(self, index):
        """Return ``(text, brain_map)`` for the row at position ``index``.

        Raises:
            Exception: if ``source`` is neither ``'title'`` nor ``'aug'``.
        """
        record = self.metadata.iloc[index]
        sample_id = record['id']

        if self.source == 'title':
            text = record['title'].lower()
        elif self.source == 'aug':
            aug_path = os.path.join(self.text_dir, str(sample_id) + '.npy')
            # The .npy file is unpickled; presumably it stores a dict with the
            # keys in _AUG_FIELDS. Only load files from a trusted source.
            text = np.load(aug_path, allow_pickle=True).tolist()
            text['Title'] = record['title'].lower()
            for field in self._AUG_FIELDS:
                text[field] = text[field].lower()
        else:
            raise Exception("Data source not implemented")

        volume_path = os.path.join(self.brain_dir, str(sample_id) + '.nii.gz')
        volume = nib.load(volume_path).get_fdata()
        # Crop the borders, then add a leading channel axis.
        volume = np.expand_dims(volume[3:-3, 3:-4, :-6], 0)
        # Scale by the maximum; NaNs from a zero maximum are replaced below.
        volume = volume / np.max(volume)
        volume = np.nan_to_num(volume, copy=False)
        return text, volume.astype(np.float32)

    def __len__(self):
        """Number of metadata rows."""
        return len(self.metadata.index)

utils

In [None]:
def compute_corr_coeff(A, B):
    """Pearson correlation between every row of ``A`` and every row of ``B``.

    Both inputs are 2-D arrays with the same number of columns. Entry
    ``[i, j]`` of the result is the correlation of ``A[i]`` with ``B[j]``.
    A constant row has zero variance and therefore yields NaN entries.
    """
    # Centre each row on its own mean.
    a_centered = A - A.mean(1, keepdims=True)
    b_centered = B - B.mean(1, keepdims=True)

    # Per-row sums of squared deviations.
    a_sq = np.square(a_centered).sum(1)
    b_sq = np.square(b_centered).sum(1)

    # Covariance-like numerator over the product of the row norms.
    numerator = a_centered.dot(b_centered.T)
    denominator = np.sqrt(a_sq[:, None] * b_sq[None, :])
    return numerator / denominator


def save_checkpoint(model, optimizer, scheduler, epoch, fname, output_dir):
    """Write model, optimizer and scheduler state to ``output_dir/fname``.

    The saved dict has the keys ``epoch``, ``state_dict``, ``scheduler`` and
    ``optimizer``. ``model.state_dict()`` is stored as returned, so a wrapped
    model keeps whatever key prefix its wrapper adds.
    """
    target_path = os.path.join(output_dir, fname)
    payload = dict(
        epoch=epoch,
        state_dict=model.state_dict(),
        scheduler=scheduler.state_dict(),
        optimizer=optimizer.state_dict(),
    )
    torch.save(payload, target_path)

utils_ddp

In [None]:
def reduce_mean(tensor, nprocs):
    """Return the mean of ``tensor`` across all processes in the group.

    A copy of ``tensor`` is summed over the process group with ``all_reduce``
    and divided by ``nprocs``; the input tensor itself is left untouched.
    """
    averaged = tensor.clone()
    dist.all_reduce(averaged, op=dist.ReduceOp.SUM)
    averaged /= nprocs
    return averaged

train

In [None]:
def train(local_rank, model, train_loader, optimizer, loss_fn, source):
    """Run one training epoch on this rank.

    For ``source == 'title'`` each batch takes one optimizer step. For
    ``source == 'aug'`` each batch takes seven steps, one per text variant
    (Title, Title1, Title2, Abstract, Experiment, Keyword, then Title again),
    all against the same target brain map.

    Args:
        local_rank: CUDA device index the targets are moved to.
        model: callable mapping a batch of texts to predicted volumes.
        train_loader: iterable of ``(text, brain)`` batches supporting ``len``.
        optimizer: optimizer stepped after every forward/backward pass.
        loss_fn: loss taking ``(output, target)``.
        source: ``'title'`` or ``'aug'``.

    Returns:
        ``(avg_loss, avg_corr)``: the loss averaged over every optimizer step
        (and over ranks via ``reduce_mean``), and the per-batch correlation of
        the last prediction with the target averaged over batches (this rank
        only).

    Raises:
        ValueError: if ``source`` is not ``'title'`` or ``'aug'``.
    """
    aug_keys = ('Title', 'Title1', 'Title2', 'Abstract', 'Experiment', 'Keyword', 'Title')
    if source not in ('title', 'aug'):
        raise ValueError("Unknown source {!r}; expected 'title' or 'aug'".format(source))

    model.train()
    avg_loss = 0
    avg_corr = 0
    num_batches = len(train_loader)
    world_size = dist.get_world_size()

    for batch_idx, (text, brain) in enumerate(train_loader):
        brain_map = brain.float().cuda(local_rank)

        if source == 'title':
            text_variants = [text]
        else:
            text_variants = [text[key] for key in aug_keys]

        for text_variant in text_variants:
            optimizer.zero_grad()

            output = model(text_variant)

            loss = loss_fn(output, brain_map)
            loss.backward()
            optimizer.step()
            # Detached: only the value is averaged across ranks for logging.
            loss = reduce_mean(loss.detach(), world_size)
            # Normalise by the number of steps actually taken this epoch, so
            # avg_loss is a true per-step mean in both modes.
            avg_loss += loss.item() / (num_batches * len(text_variants))

        # Correlation is measured once per batch, on the last prediction.
        output_np = output.cpu().detach().numpy()
        target_np = brain_map.cpu().detach().numpy()

        all_corr = compute_corr_coeff(
            output_np.reshape(output_np.shape[0], -1),
            target_np.reshape(output_np.shape[0], -1))
        corr = np.mean(np.diag(all_corr))
        if np.isnan(corr):
            # Dump diagnostics and abort this process.
            print(text,
                  np.isnan(output_np).any(),
                  np.isnan(target_np).any(), all_corr, corr)
            print("Output", torch.max(output), output)
            print("Target", torch.max(brain_map), brain_map)
            sys.exit(1)

        # One correlation value per batch, so divide by the batch count only.
        avg_corr = avg_corr + corr / num_batches

        if batch_idx % 100 == 99:
            # Use the real batch size: in 'aug' mode len(text) is the number
            # of dict keys, not the number of samples.
            print('[{}/{} ({:.0f}%)] Loss: {:.6f} Corr: {:.6f}'.format(
                batch_idx * brain_map.shape[0], len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item(), corr))

    print('  Train: avg loss: {:.6f} - avg corr: {:.6f}'.format(avg_loss, avg_corr))

    return avg_loss, avg_corr

eval

In [None]:
def eval(local_rank, model, val_loader, loss_fn, source):
    """Evaluate the model on the validation loader without gradient tracking.

    For ``source == 'title'`` each batch is scored once. For
    ``source == 'aug'`` each batch is scored on seven text variants (Title,
    Title1, Title2, Abstract, Experiment, Keyword, then Title again) against
    the same target brain map.

    Args:
        local_rank: CUDA device index the targets are moved to.
        model: callable mapping a batch of texts to predicted volumes.
        val_loader: iterable of ``(text, brain)`` batches supporting ``len``.
        loss_fn: loss taking ``(output, target)``.
        source: ``'title'`` or ``'aug'``.

    Returns:
        ``(avg_loss, avg_corr)``: the loss averaged over every forward pass
        (and over ranks via ``reduce_mean``), and the per-batch correlation of
        the last prediction with the target averaged over batches (this rank
        only).

    Raises:
        ValueError: if ``source`` is not ``'title'`` or ``'aug'``.
    """
    aug_keys = ('Title', 'Title1', 'Title2', 'Abstract', 'Experiment', 'Keyword', 'Title')
    if source not in ('title', 'aug'):
        raise ValueError("Unknown source {!r}; expected 'title' or 'aug'".format(source))

    model.eval()
    avg_loss = 0
    avg_corr = 0
    num_batches = len(val_loader)
    world_size = dist.get_world_size()

    with torch.no_grad():
        for batch_idx, (text, brain) in enumerate(val_loader):
            brain_map = brain.float().cuda(local_rank)

            if source == 'title':
                text_variants = [text]
            else:
                text_variants = [text[key] for key in aug_keys]

            for text_variant in text_variants:
                output = model(text_variant)

                loss = loss_fn(output, brain_map)
                loss = reduce_mean(loss, world_size)
                # Normalise by the number of forward passes actually made, so
                # avg_loss is a true mean in both modes.
                avg_loss += loss.item() / (num_batches * len(text_variants))

            # Correlation is measured once per batch, on the last prediction.
            output_np = output.cpu().detach().numpy()
            target_np = brain_map.cpu().detach().numpy()
            corr = np.mean(
                np.diag(
                    compute_corr_coeff(
                        output_np.reshape(output_np.shape[0], -1),
                        target_np.reshape(output_np.shape[0], -1))))
            # One correlation value per batch, so divide by the batch count only.
            avg_corr = avg_corr + corr / num_batches

        print('  Val: avg loss: {:.6f} - avg corr: {:.6f}'.format(
            avg_loss, avg_corr))

        return avg_loss, avg_corr

main

python -m torch.distributed.launch --nproc_per_node=**[num_gpu(int)]** **[python script]**

For example, on a single node with 8 GPUs:

python -m torch.distributed.launch --nproc_per_node=**8 Chat2Brain_train.py**

In [None]:
if __name__=='__main__':
    # Entry point for one DDP worker process. Hyper-parameters are set inline
    # because the launcher owns the command line (it supplies the local rank).

    os.environ["CUDA_DEVICE_ORDER"] = 'PCI_BUS_ID'
    os.environ["CUDA_VISIBLE_DEVICES"] = '0, 1, 2, 3, 4, 5, 6, 7'

    # Run configuration (previously repeated as literals below).
    num_epochs = 2000
    checkpoint_interval = 10  # epochs between checkpoints / size of the smoothing window
    total_batch_size = 48     # summed over all GPUs

    # init ddp
    parser = argparse.ArgumentParser()
    # Launchers pass the rank either as --local_rank / --local-rank or through
    # the LOCAL_RANK environment variable; accept all three.
    parser.add_argument('--local_rank', '--local-rank', dest='local_rank', type=int,
                        default=int(os.environ.get('LOCAL_RANK', 0)))
    args = parser.parse_args()

    # world_size must equal the number of launched processes; the launcher
    # exports it as WORLD_SIZE. Falls back to the original 8-GPU setup.
    world_size = int(os.environ.get('WORLD_SIZE', 8))
    dist.init_process_group(backend='nccl', world_size=world_size, rank=args.local_rank)
    torch.cuda.set_device(args.local_rank)

    batch_size = total_batch_size // dist.get_world_size() # total batchsize is the sum of each gpu's batchsize

    output_name = f'neuroquery_aug_fc1024_dec256_lr3e-2_decay1e-6_drop0.6_seed28'
    output_dir = os.path.join("/data3/weiyaonai/project/chat2brain/chat2brain/Chat2Brain_checkpoint/", output_name)
    if os.path.exists(output_dir) and dist.get_rank() == 0:
        print(f'Output dir exists: {output_dir}')
    # exist_ok: every rank reaches this line, so a check-then-create would race.
    os.makedirs(output_dir, exist_ok=True)

    mask_dir = "/data3/weiyaonai/project/chat2brain/chat2brain/data/brain_maps/neuroquery/"
    aug_text_dir = "/data3/weiyaonai/project/chat2brain/chat2brain/data/ChatAUG/AUG_1/"

    if dist.get_rank() == 0:
        writer = SummaryWriter(os.path.join(output_dir, "logs")) # tensorboard

    # load Data
    np.random.seed(28)
    meta_data = pd.read_table("/data3/weiyaonai/project/chat2brain/chat2brain/data-neuroquery_version-1_metadata.tsv.gz")

    # randomly divide the dataset into training, validation and test sets (6 : 2 : 2)
    train_meta = meta_data.sample(frac=6 / (6 + 2 + 2), random_state=60)

    # Rows not in the training split. pd.concat replaces DataFrame.append,
    # which no longer exists in pandas 2.x; the result is the same.
    between_meta = pd.concat([meta_data, train_meta]).drop_duplicates(keep=False)

    val_meta = between_meta.sample(frac=2 / (2 + 2), random_state=60)

    train_dataset = Chat2BrainDataset(train_meta, aug_text_dir, mask_dir, "aug")
    val_dataset = Chat2BrainDataset(val_meta, aug_text_dir, mask_dir, "aug")

    # data parallelism: each rank sees its own shard of the data
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)

    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, sampler=train_sampler)
    val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, sampler=val_sampler)
    if dist.get_rank() == 0:
        print("Number of training articles:", len(train_dataset))
        print("Number of validation articles:", len(val_dataset))

    # init model
    model = Text2BrainModel(
        out_channels=1,
        fc_channels=1024,
        decoder_filters=256,
        pretrained_bert_dir="/data3/weiyaonai/project/chat2brain/chat2brain/scibert_scivocab_uncased",
        drop_p=0.6).cuda(args.local_rank)
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) # improved accuracy for small batch sizes

    # one model replica per GPU; gradients are averaged across replicas
    model = torch.nn.parallel.DistributedDataParallel(model,
                                                      device_ids=[args.local_rank],
                                                      output_device=args.local_rank,
                                                      find_unused_parameters=True,
                                                      broadcast_buffers=True
                                                      )

    # Optimizer
    # NOTE(review): the schedule horizon is counted in samples * epochs while
    # sched.step() is called once per epoch, so the linear warm-up appears not
    # to complete within the run. Kept as-is because it fixes the effective
    # learning rate of the published configuration; confirm before changing.
    num_training_steps = len(train_dataset) * num_epochs
    num_warmup_steps = num_training_steps // 3
    # transformers.AdamW is deprecated and may be absent from newer releases;
    # fall back to the PyTorch implementation. eps is spelled out so both
    # implementations use the same value (1e-6 is the transformers default).
    adamw_cls = getattr(transformers, 'AdamW', None) or torch.optim.AdamW
    opt = adamw_cls([
        {'params': model.module.fc.parameters()},
        {'params': model.module.decoder.parameters()},
        {'params': model.module.encoder.parameters(), 'lr': 1e-5},
    ], lr=3e-2, weight_decay=1e-6, eps=1e-6)
    sched = transformers.get_linear_schedule_with_warmup(opt, num_warmup_steps, num_training_steps)

    loss_fn = nn.MSELoss(reduction="sum").cuda(args.local_rank)

    val_losses = []
    val_corrs = []
    best_loss = sys.float_info.max
    best_corr = 0.0

    # Only rank 0 draws the progress bar.
    for epoch in tqdm(range(num_epochs), disable=dist.get_rank() != 0):
        # Reseed the samplers so each epoch uses a different shuffle.
        train_sampler.set_epoch(epoch)
        val_sampler.set_epoch(epoch)

        train_loss, train_corr = train(args.local_rank, model, train_loader, opt, loss_fn, 'aug')
        val_loss, val_corr = eval(args.local_rank, model, val_loader, loss_fn, 'aug')

        dist.barrier()

        if args.local_rank == 0: # only work in local_rank 0 (first gpu)
            writer.add_scalar('training loss', train_loss, epoch)
            writer.add_scalar('training corr', train_corr, epoch)

            writer.add_scalar('validation loss', val_loss, epoch)
            writer.add_scalar('validation corr', val_corr, epoch)

        val_losses.append(val_loss)
        val_corrs.append(val_corr)

        # Smooth the validation metrics over the last checkpoint_interval epochs.
        mean_loss = np.mean(val_losses[-checkpoint_interval:])
        mean_corr = np.mean(val_corrs[-checkpoint_interval:])

        # Only rank 0 writes files, so ranks never write the same path at once.
        if dist.get_rank() == 0:
            if (epoch > num_epochs * 0.1) and (epoch % checkpoint_interval == 0):
                if mean_loss < best_loss:
                    save_checkpoint(model, opt, sched, epoch, "best_loss.pth", output_dir)
                    best_loss = mean_loss
                if mean_corr > best_corr:
                    save_checkpoint(model, opt, sched, epoch, "best_corr.pth", output_dir)
                    best_corr = mean_corr
                save_checkpoint(model, opt, sched, epoch, f'checkpoint_{epoch}.pth', output_dir)

        # Every rank owns its own optimizer, so every rank must advance its
        # scheduler; otherwise the replicas would train with different
        # learning rates and drift apart.
        sched.step()

    if dist.get_rank() == 0:
        # Final checkpoint, written once after the last epoch.
        save_checkpoint(model, opt, sched, num_epochs, f'checkpoint_{num_epochs}.pth', output_dir)
        writer.close()

# Chat2Brain_Test

In [None]:
!pip install -q rouge
!pip install -q openai
!pip install -q evaluate
!pip install -q rouge_score

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import os
import time
import openai
import evaluate
import concurrent.futures


import numpy as np
import pandas as pd


from rouge import Rouge
from rouge.rouge import rouge_score
from tqdm import tqdm

chatgpt

In [None]:
def chat2brain_gpt(text, near_samples=None, interactive=False, former_good_response=None, former_bad_response=None):
  """Ask gpt-3.5-turbo to rewrite ``text`` as a brain-science paper title.

  The conversation is: a system instruction, optional few-shot examples built
  from ``near_samples``, the question for ``text`` and, in interactive mode,
  earlier answers labelled as bad or excellent followed by an instruction on
  how to improve on them.

  This uses the ``openai.ChatCompletion`` / ``openai.error`` interface, i.e.
  the pre-1.0 ``openai`` package.

  Args:
    text: source text to turn into a title.
    near_samples: optional list of dicts with ``id`` and ``title``; for each,
      ``<aug_dir>/<id>.npy`` is loaded and its ``Title1`` entry is used as the
      example input.
    interactive: when true, feed back ``former_good_response`` and
      ``former_bad_response``.
    former_good_response: list of ``{'title': ...}`` dicts, or None.
    former_bad_response: list of ``{'title': ...}`` dicts, or None.

  Returns:
    The model's reply as a string; a single newline if the last attempt
    failed with an API, timeout, connection or rate-limit error; ``-2`` if it
    failed with an invalid-request error (which the caller reads as "prompt
    too long").
  """
  max_tries = 2

  dynamic_messages = [
      {"role": "system", "content": "You are a brain science expert. Rewrite a research paper title \
       that contains key information about the brain science based on the given original text. \
        Please ensure that your response is concise and does not exceed the length of the research paper title."}]
  if near_samples is not None:
    dynamic_messages.append({"role": "user", "content": "Here are some examples. Please learn how to write research paper title in these examples, and \
    pay particular attention to the use of brain science concepts in the following example."})
    for near in near_samples:
      aug_path = os.path.join(aug_dir, str(near["id"]) + ".npy")
      # NOTE: allow_pickle unpickles the file; only load trusted files.
      aug_npy = np.load(aug_path, allow_pickle=True).tolist()

      dynamic_messages.append({"role": "user", "content": "What are the research paper title that contains key information about the brain science based on the given original text:\
          \nTEXT:\n{}".format(aug_npy["Title1"].replace('\n', ''))})
      dynamic_messages.append({"role": "assistant", "content": "TITLE:\n{}".format(near["title"].replace('\n', ''))})

  # The request is assembled once, outside the retry loop, so a retry resends
  # the same conversation instead of one with the question appended again.
  if not interactive:
    dynamic_messages.append({"role": "user", "content": "What are the research paper title that contains key information about the brain science \
         based on the given original text: \nTEXT: {}".format(text)})
  else:
    dynamic_messages.append({"role": "user", "content": "What are the research paper title that contains key information about the brain science \
                        based on the given original text: \nTEXT: {}".format(text)})
    # A feedback list that was not supplied counts as empty when the other
    # list's branch asks for its length.
    good_examples = former_good_response if former_good_response is not None else []
    bad_examples = former_bad_response if former_bad_response is not None else []

    if former_bad_response is not None:
      for item in bad_examples:
        dynamic_messages.append({"role": "user", "content": "Below is a bad research paper title based on the given original text above:"})
        dynamic_messages.append({"role": "assistant", "content": "TITLE:\n{}".format(item["title"].replace('\n', ''))})

      if len(good_examples) == 0:
        dynamic_messages.append({"role": "user", "content": "Please give another research paper title based on the given original text. \
            It is important to note that your answer should avoid being consistent with the bad research paper title above. \
            Please pay particular attention to the consistent use of phrases in the above examples"})

    if former_good_response is not None:
      for item in good_examples:
        dynamic_messages.append({"role": "user", "content": "Below is an excellent research paper title based on the given original text above:"})
        dynamic_messages.append({"role": "assistant", "content": "TITLE:\n{}".format(item["title"].replace('\n', ''))})

      if len(bad_examples) == 0:
        dynamic_messages.append({"role": "user", "content": "This is an excellent research paper title, \
              please give another research paper title in this format, and do not exceed the length of this research paper title.\
              Please pay particular attention to the consistent use of phrases in the above examples"})
      else:
        dynamic_messages.append({"role": "user", "content": "Please give another research paper title based on the given original text above. \
              It is important to note that your answer should avoid being consistent with the bad research paper title above, \
              but should be consistent with the excellent research paper title above, and do not exceed the length of the excellent research paper title. \
              And please pay particular attention to the consistent use of phrases in the above examples"})

  # (exception type, label printed on final failure, value returned on final failure)
  failure_modes = (
      (openai.error.APIError, 'APIError', '\n'),
      (openai.error.Timeout, 'Timeout', '\n'),
      (openai.error.APIConnectionError, 'APIConnectionError', '\n'),
      (openai.error.RateLimitError, 'RateLimitError', '\n'),
      (openai.error.InvalidRequestError, 'InvalidRequestError', -2),
  )
  handled_errors = tuple(mode[0] for mode in failure_modes)

  for try_number in range(max_tries):
    try:
      response = openai.ChatCompletion.create(
        model="gpt-3.5-turbo",
        messages=dynamic_messages)
      return response['choices'][0]['message']['content']
    except handled_errors as e:
      if try_number == max_tries - 1:
        # First matching row wins, mirroring the order of the original handlers.
        label, fallback = next(
            (name, value) for exc_type, name, value in failure_modes if isinstance(e, exc_type))
        print(label)
        return fallback
      time.sleep(0.1)

finding similar examples in the corpus

In [None]:
def compute_near(test_sample, train_data, near_num, rouge_type):
  """Return the ``near_num`` training titles most similar to ``test_sample``.

  Every row of ``train_data`` is scored with the module-level ``rouge``
  scorer (F-measure of ``rouge_type``). The result is a list of dicts with
  the keys ``id``, ``score`` and ``title``, best score first.
  """
  scored = []
  for row_label in train_data.index:
    candidate_title = train_data.at[row_label, 'title']
    rouge_result = rouge.get_scores(test_sample, candidate_title)
    scored.append({
        'id': int(train_data.at[row_label, 'id']),
        'score': rouge_result[0][rouge_type]['f'],
        'title': candidate_title,
    })

  # Stable sort, highest score first; keep only the requested number.
  ranked = sorted(scored, key=lambda entry: entry['score'], reverse=True)
  return ranked[:near_num]

impressionGPT for Chat2Brain

In [None]:
def chat2brain_test(args):
  """Generate the best rewritten title for one test row, with caching.

  The first request is a plain few-shot prompt. Every later request is
  interactive: it feeds back the best answer so far and answers recorded as
  bad. Each answer is scored by its mean ROUGE F-measure against the titles
  of the nearest training samples; the highest-scoring answer wins.

  Args:
    args: ``(row_index, row)`` as produced by ``DataFrame.iterrows()``; the
      row must provide ``id`` and ``title``.

  Returns:
    ``(row_index, best_title)``. ``best_title`` is a string; it is empty when
    no usable answer could be obtained (nothing is cached in that case, so a
    later run retries the row).
  """
  row_index, row = args

  sample_id, title = row['id'], row['title']

  outdir = os.path.join('/content/drive/MyDrive/Colab/CoordinateGPT/result/output', str(sample_id) + '.npy')
  if os.path.exists(outdir):
    # Already answered in an earlier run.
    return row_index, np.load(outdir).tolist()

  near_sample = compute_near(title, train_meta, near_k_samples, rouge_type)

  good_response = []
  bad_response = []
  former_score = 0

  all_response_score, all_response = [], []
  try_count = 0

  # Give up on a row after this many consecutive unusable replies instead of
  # retrying forever.
  max_failed_calls = 5
  failed_calls = 0

  # Scoring needs at least one reference title, so stop once none are left.
  while near_sample:
    try:
      if len(good_response) == 0 and len(bad_response) == 0:
        response = chat2brain_gpt(title, near_samples=near_sample)
      else:
        response = chat2brain_gpt(title, near_samples=near_sample, interactive=interactive,
                                  former_good_response=good_response, former_bad_response=bad_response)
    except Exception as e:
      failed_calls += 1
      if failed_calls >= max_failed_calls:
        print('giving up on id {} after {} failed calls: {!r}'.format(sample_id, failed_calls, e))
        break
      time.sleep(5)
      continue

    if response == -2:
      # The prompt was rejected; shorten it by dropping the two least similar
      # examples (or whatever is left) and try again.
      print("exceed length, pop 2 similar examples")
      del near_sample[-2:]
      continue

    response = response.replace('\n', '') if isinstance(response, str) else ''

    try:
      if not response.strip():
        # chat2brain_gpt returns a bare newline when the API call failed.
        raise ValueError('empty response')
      compare_scores = [
          rouge.get_scores([response], [near_sa['title']])[0][rouge_type]['f']
          for near_sa in near_sample]
    except ValueError:
      # Empty or unscorable reply: counts as a failed call, not as a try.
      failed_calls += 1
      if failed_calls >= max_failed_calls:
        print('giving up on id {} after {} unusable responses'.format(sample_id, failed_calls))
        break
      continue

    failed_calls = 0
    score = np.mean(np.array(compare_scores))

    all_response_score.append(score)
    all_response.append(response)
    try_count += 1

    if score >= rouge_thre and score > former_score:
      # New best answer: it replaces the previous good example.
      former_score = score
      good_response.clear()
      good_response.append({'title': response, 'score': score})

    if score < rouge_thre and score < former_score:
      # Keep only the most recent bad examples so the prompt stays short.
      if len(bad_response) > 8:
        bad_response = bad_response[-8:]
      bad_response.append({'title': response, 'score': score})

    # One initial request plus interactive_times interactive ones.
    if try_count > interactive_times:
      break

  if not all_response:
    return row_index, ''

  max_score_index = all_response_score.index(max(all_response_score))
  final_best_response = all_response[max_score_index]

  # The cache directory is not created anywhere else.
  os.makedirs(os.path.dirname(outdir), exist_ok=True)
  np.save(outdir, np.array(final_best_response))

  # Return a plain string, the same type the cached branch returns.
  return row_index, final_best_response

main

In [None]:
# Shared ROUGE scorer used by compute_near and chat2brain_test.
rouge = Rouge()

# Read the key from the environment when it is set, so it does not have to be
# pasted into the notebook; the placeholder is kept as the fallback.
openai.api_key = os.environ.get('OPENAI_API_KEY', '[openai_key]')

interactive = True
interactive_times = 10 # upper limit on the number of times a single sample can interact with chatgpt
rouge_thre = 0.7 # thresholds for determining excellent response
near_k_samples = 14 # number of approximate samples
rouge_type = 'rouge-1'

aug_dir = '/content/drive/MyDrive/Colab/CoordinateGPT/ChatAUG/AUG_1'

train_meta = pd.read_csv('/content/drive/MyDrive/Colab/CoordinateGPT/train_meta.csv')
val_meta = pd.read_csv('/content/drive/MyDrive/Colab/CoordinateGPT/val_meta.csv')
test_meta = pd.read_csv('/content/drive/MyDrive/Colab/CoordinateGPT/test_meta.csv')

text_dir = '/content/drive/MyDrive/Colab/CoordinateGPT/ChatAUG/AUG_1'
mask_dir = '/content/drive/MyDrive/Colab/CoordinateGPT/neuroquery_brain_maps'
save_path = '/content/drive/MyDrive/Colab/CoordinateGPT/result'
os.makedirs(save_path, exist_ok=True)
print(f'Save path: {save_path}')

# Drop the first column of the CSV (presumably a saved index column; confirm).
test_meta = test_meta.drop(test_meta.columns[0], axis=1)

# One chat2brain_test call per test row, run on a thread pool; results are
# collected as {row_index: best_title}.
with concurrent.futures.ThreadPoolExecutor() as executor:
  result = dict(tqdm(executor.map(chat2brain_test, test_meta.iterrows()), total=len(test_meta)))

test_meta['Result'] = test_meta.index.map(result.get)
# The template has five fields. The last one ("num") had no argument, which
# raised IndexError before the CSV could be written; it is filled with the
# number of test rows. NOTE(review): confirm that is the intended meaning.
out_file_name = 'test_time{}_thre{}_near{}_{}_num{}.csv'.format(
    interactive_times, rouge_thre, near_k_samples, rouge_type, len(test_meta))
output_file = os.path.join(save_path, '{}'.format(out_file_name))
test_meta.to_csv(output_file, index=False)

# Chat2Brain_Score

In [None]:
!pip install -q transformers

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import os
import sys
import torch
import transformers
import scipy.stats


import numpy as np
import pandas as pd
import torch.nn as nn
import nibabel as nib


from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from sklearn.metrics import roc_auc_score
from sklearn.metrics import f1_score
from tqdm import tqdm
from torch.nn.utils.rnn import pad_sequence
from argparse import ArgumentParser

model_decoder

In [None]:
class SimpleConvResBlock3D(nn.Module):
    """One 3D convolution followed by batch normalisation and an activation.

    The convolution uses a 3x3x3 kernel with stride 1 and padding 2, so each
    spatial dimension of the output is 2 voxels larger than the input.

    Args:
        in_channels: number of channels in the input volume.
        out_channels: number of channels produced by the convolution.
        act_fn: callable applied to the normalised feature map.
    """

    def __init__(self, in_channels, out_channels, act_fn):
        super().__init__()
        self.conv1 = nn.Conv3d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            padding=2,
        )
        self.bn1 = nn.BatchNorm3d(out_channels)
        self.act_fn = act_fn

    def forward(self, input_):
        """Return ``act_fn(bn1(conv1(input_)))``."""
        return self.act_fn(self.bn1(self.conv1(input_)))


class ConvResBlock3D(nn.Module):
    """Two-convolution residual block that keeps the spatial size unchanged.

    The output of the first convolution is used as the shortcut; the residual
    branch is ``conv2(act_fn(bn1(shortcut)))``. Their sum is batch-normalised
    and passed through the activation.

    Args:
        in_channels: number of channels in the input volume.
        out_channels: number of channels produced by both convolutions.
        act_fn: callable used after each batch normalisation.
    """

    def __init__(self, in_channels, out_channels, act_fn):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1)
        self.conv1 = nn.Conv3d(in_channels, out_channels, **conv_kwargs)
        self.bn1 = nn.BatchNorm3d(out_channels)
        self.act_fn = act_fn
        self.conv2 = nn.Conv3d(out_channels, out_channels, **conv_kwargs)
        self.bn2 = nn.BatchNorm3d(out_channels)

    def forward(self, input_):
        """Apply the block to ``input_`` and return the activated sum."""
        shortcut = self.conv1(input_)
        branch = self.conv2(self.act_fn(self.bn1(shortcut)))
        return self.act_fn(self.bn2(shortcut + branch))


class TransConvResBlock3D(nn.Module):
    """Up-sampling block: transposed 3D convolution, batch norm, activation.

    With kernel 4, stride 2, padding 1 and output_padding 1 a spatial extent
    of ``n`` becomes ``2 * n + 1``.

    Args:
        in_channels: number of channels in the input volume.
        out_channels: number of channels produced by the transposed convolution.
        act_fn: callable applied to the normalised feature map.
    """

    def __init__(self, in_channels, out_channels, act_fn):
        super().__init__()
        self.trans_conv1 = nn.ConvTranspose3d(
            in_channels,
            out_channels,
            kernel_size=4,
            stride=2,
            padding=1,
            output_padding=1,
        )
        self.bn1 = nn.BatchNorm3d(out_channels)
        self.act_fn = act_fn

    def forward(self, input_):
        """Return ``act_fn(bn1(trans_conv1(input_)))``."""
        return self.act_fn(self.bn1(self.trans_conv1(input_)))


class ImageDecoder(nn.Module):
    """Decode a low-resolution feature volume into a brain-map volume.

    Three ``TransConvResBlock3D`` stages reduce the channel count from
    ``num_filter`` to ``num_filter // 4``; a ``SimpleConvResBlock3D`` then maps
    to ``out_channels``. The first slice along each of the three spatial axes
    of the result is dropped before returning.

    Args:
        in_channels: channels of the input feature volume.
        out_channels: channels of the decoded volume.
        act_fn: accepted for interface compatibility only. It is overridden
            below by a fixed ``Hardtanh(-6, 6)`` shared by every stage.
        num_filter: channel width of the first up-sampling stage.
    """

    def __init__(self, in_channels, out_channels, act_fn=nn.Sigmoid, num_filter=256):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_filter = num_filter
        # The caller-supplied activation is deliberately not used: every stage
        # shares this single Hardtanh instance.
        act_fn = nn.Hardtanh(min_val=-6, max_val=6)

        half_width = self.num_filter // 2
        quarter_width = self.num_filter // 4

        self.trans_1 = TransConvResBlock3D(self.in_channels, self.num_filter, act_fn)
        self.trans_2 = TransConvResBlock3D(self.num_filter, half_width, act_fn)
        self.trans_3 = TransConvResBlock3D(half_width, quarter_width, act_fn)

        self.out = SimpleConvResBlock3D(quarter_width, self.out_channels, act_fn)

    def forward(self, input_):
        """Run the three up-sampling stages and the output stage, then crop."""
        features = input_
        for upsample in (self.trans_1, self.trans_2, self.trans_3):
            features = upsample(features)

        decoded = self.out(features)

        # Drop the leading slice on each spatial axis.
        return decoded[:, :, 1:, 1:, 1:]

model_main_with_encoder

In [None]:
class Text2BrainModel(nn.Module):
    """Map free-text queries to 3-D volumes via a BERT encoder and a 3-D decoder.

    Each text is tokenized and encoded by a pretrained BERT model. The pooled
    embedding is projected by a fully connected layer, reshaped to
    ``(fc_channels, *decoder_input_shape)`` and passed to ``ImageDecoder``.

    Args:
        out_channels: Stored on the instance. NOTE(review): the decoder is
            always constructed with ``out_channels=1``, so this value does not
            currently change the network.
        fc_channels: Channel count of the volume handed to the decoder.
        decoder_filters: Forwarded to ``ImageDecoder`` as ``num_filter``.
        pretrained_bert_dir: Passed to ``from_pretrained`` for both the
            tokenizer and the encoder.
        decoder_act_fn: Forwarded to ``ImageDecoder`` as ``act_fn``.
        drop_p: Dropout probability used before and after the FC layer.
        decoder_input_shape: The three spatial sizes of the decoder input.
    """

    def __init__(self, out_channels, fc_channels, decoder_filters, pretrained_bert_dir, decoder_act_fn=nn.Sigmoid, drop_p=0.5, decoder_input_shape=[4, 5, 4]):
        super().__init__()
        self.out_channels = out_channels
        self.fc_channels = fc_channels
        self.decoder_filters = decoder_filters
        # Copy into a fresh list: the default argument is a shared mutable
        # object, and ``forward`` concatenates this value with another list
        # (so a tuple passed by a caller must be converted as well).
        self.decoder_input_shape = list(decoder_input_shape)
        self.drop_p = drop_p

        self.tokenizer = transformers.BertTokenizer.from_pretrained(pretrained_bert_dir)
        self.encoder = transformers.BertModel.from_pretrained(pretrained_bert_dir)
        if torch.cuda.is_available():
            self.encoder = self.encoder.cuda()

        n_voxels = self.decoder_input_shape[0] * self.decoder_input_shape[1] * self.decoder_input_shape[2]
        # 768 is the embedding width this layer expects from the encoder.
        self.fc = nn.Linear(
            in_features=768,
            out_features=n_voxels * self.fc_channels)
        self.dropout = nn.Dropout(self.drop_p)
        self.relu = nn.ReLU()

        self.decoder = ImageDecoder(in_channels=self.fc_channels, out_channels=1, num_filter=self.decoder_filters, act_fn=decoder_act_fn)


    def forward(self, texts):
        """Return the decoded volume for every string in ``texts``.

        Args:
            texts: Iterable of strings; one output volume is produced per item.
        """
        batch = [self._tokenize(x) for x in texts]

        in_mask = self._pad_mask(batch, batch_first=True)
        in_ = pad_sequence(batch, batch_first=True)
        if torch.cuda.is_available():
            in_ = in_.cuda()
            in_mask = in_mask.cuda()

        # Take the second output (the pooled embedding) by index instead of
        # tuple-unpacking. Recent ``transformers`` releases return a dict-like
        # ModelOutput whose iteration yields key strings, so
        # ``_, embedding = ...`` would bind a str; integer indexing works for
        # both the legacy tuple and ModelOutput.
        encoder_outputs = self.encoder(in_, attention_mask=in_mask)
        embedding = encoder_outputs[1]

        x = self.dropout(embedding)
        x = self.fc(x)
        x = self.dropout(x)
        x = self.relu(x)

        decoder_tensor_shape = [-1, self.fc_channels] + self.decoder_input_shape
        x = x.view(decoder_tensor_shape)

        out = self.decoder(x)

        return out


    def _tokenize(self, text):
        """Encode ``text`` to a 1-D tensor of token ids (truncated to 512 tokens)."""
        return self.tokenizer.encode(text, add_special_tokens=True, return_tensors='pt', truncation=True, max_length=512).squeeze(0)


    def _pad_mask(self, sequences, batch_first=False):
        """Build a padded mask: ones over each sequence's length, zeros over padding."""
        ret = [torch.ones(len(s)) for s in sequences]
        return pad_sequence(ret, batch_first=batch_first)

args

In [None]:
def init_args(argv=None):
    """Build the command-line parser and return the parsed options.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) parses
            ``sys.argv[1:]`` exactly like a plain script would. Pass ``[]`` to
            obtain the defaults.

    Returns:
        argparse.Namespace holding every option defined below.

    Unrecognised arguments are reported with a warning instead of aborting.
    A Jupyter/Colab kernel puts its own ``-f <connection file>`` pair in
    ``sys.argv``, which made the strict ``parse_args()`` call exit the cell.
    """
    import warnings

    parser = ArgumentParser()

    parser.add_argument("--gpus", type=str,
                        default="0, 1, 2, 3",
                        help="Which gpus to use?")

    parser.add_argument("--ver", type=str,
                        default="neuroquery",
                        help="Additional string for the name of the file")

    parser.add_argument("--train_csv",
                        type=str,
                        help="Path to the csv containing the training articles data")

    parser.add_argument("--val_csv",
                        type=str,
                        help="Path to the csv containing the validation articles data")

    parser.add_argument("--images_dir",
                        type=str,
                        help="Directory containing activation maps, should be of size (40, 48, 40)")

    parser.add_argument("--pretrained_bert_dir",
                        type=str,
                        default="/disk1/wyn/workshop/ChatGPT/text2brain-main/scibert_scivocab_uncased",
                        help="Directory containing pretrained BERT model")

    parser.add_argument("--pretrained_tokenizer_dir",
                        type=str,
                        help="Directory containing pretrained tokenizer")

    parser.add_argument("--mask_file",
                        type=str,
                        help="Brain mask file")

    parser.add_argument("--save_dir", type=str,
                        default="/disk1/wyn/workshop/ChatGPT/text2brain-main/Chat2Brain_checkpoint/",
                        help="Path to the output directory")

    parser.add_argument("--save_test_dir", type=str,
                        default="/disk1/wyn/workshop/ChatGPT/text2brain-main/Chat2Brain_test/",
                        help="Path to the output directory")

    parser.add_argument("--mask_dir", type=str,
                        default="/disk1/wyn/workshop/ChatGPT/text2brain-main/data/brain_maps/neuroquery/",
                        help="Path to the mask directory")

    parser.add_argument("--text_dir", type=str,
                        default="/disk1/wyn/workshop/ChatGPT/text2brain-main/data/ChatAUG/",
                        help="Path to the text directory")

    parser.add_argument("--metadata_dir", type=str,
                        default="/disk1/wyn/workshop/ChatGPT/text2brain-main/data-neuroquery_version-1_metadata.tsv.gz",
                        help="Path to the metadata directory")

    parser.add_argument("--n_fc_channels",
                        type=int,
                        default=1024,
                        help="Base number of channels in the FC layer")

    parser.add_argument("--n_decoder_channels",
                        type=int,
                        default=256,
                        help="Base number of channels in the image decoder")

    parser.add_argument("--n_output_channels",
                        type=int,
                        default=1,
                        help="Number of output channels")

    parser.add_argument("--lr",
                        type=float,
                        default=3e-2,
                        help="Learning rate")

    parser.add_argument("--weight_decay",
                        type=float,
                        default=1e-6,
                        help="Weight decay of the optimizer")

    parser.add_argument("--drop_p",
                        type=float,
                        default=0.6,
                        help="Dropout proportion for FC layer")

    parser.add_argument("--epochs",
                        type=int,
                        default=550,
                        help="Training epochs")

    parser.add_argument("--seed",
                        type=int,
                        default=28)

    parser.add_argument("--random_seed",
                        type=int,
                        default=60)

    # ``type=list`` would call list() on the raw string and split it into
    # single characters (commas included). Collect space-separated integers
    # instead, e.g. ``--split 6 2 2``.
    parser.add_argument("--split",
                        type=int,
                        nargs="+",
                        default=[6, 2, 2],
                        help="Space-separated integers describing the data split")

    parser.add_argument("--checkpoint_file",
                        type=str,
                        default="/disk1/wyn/workshop/ChatGPT/text2brain-main/Chat2Brain_checkpoint/neuroquery_title_fc1024_dec256_lr0.03_decay1e-06_drop0.6_seed28/checkpoint_1450.pth",
                        help="Path to the checkpoint file to be loaded into the model")

    parser.add_argument("--checkpoint_interval",
                        type=int,
                        default=10,
                        help="Number of epochs between saved checkpoints")

    parser.add_argument("--batch_size",
                        type=int,
                        default=24,
                        help="Batch size")

    parser.add_argument("--Scaling_factor",
                        type=int,
                        default=1,
                        help="Scaling factor")

    parser.add_argument("--phrase",
                        type=str,
                        default=None,
                        help="Input phrase for prediction")

    parser.add_argument("--source",
                        type=str,
                        default="title",
                        help="Source type")

    args, unknown = parser.parse_known_args(argv)
    if unknown:
        warnings.warn(f"init_args ignored unrecognised arguments: {unknown}")
    return args

score

In [None]:
def norm(x):
  """Min-max scale ``x`` so its smallest value maps to 0 and its largest to 1.

  A constant input has zero range; zeros of the same shape are returned in
  that case instead of dividing by zero (which would yield NaN).
  """
  x = np.asarray(x)
  lowest = np.min(x)
  span = np.max(x) - lowest
  if span == 0:
    return np.zeros_like(x, dtype=float)

  return (x - lowest) / span


def _score_prediction(true_raw, true_bin, pred):
  """Return ``(auc, dice, pval)`` for one flattened prediction.

  ``true_bin`` is the integer-cast, min-max-normalised target; ``true_raw`` is
  the un-normalised target used for the t-test.
  """
  pred_norm = norm(pred)
  auc = roc_auc_score(true_bin, pred_norm)
  dice = f1_score(true_bin, pred_norm.astype('int'))
  _, pval = scipy.stats.ttest_ind(true_raw, pred)
  return auc, dice, pval


def get_score(result_r, result_o, gt, writer=None):
  """Compute mean AUC, Dice (F1) and t-test p-value for two prediction sets.

  Args:
    result_r: Sequence of arrays, one per sample (first prediction set).
    result_o: Sequence of arrays, one per sample (second prediction set).
    gt: Sequence of target arrays; its length defines the number of samples.
    writer: Optional object with ``add_scalars(tag, scalars, step)``. When
      omitted, a module-level ``writer`` is used if one exists; if none is
      found, per-sample logging is skipped.

  Returns:
    Tuple ``(mean_auc_r, mean_dice_r, mean_pval_r, mean_auc_o, mean_dice_o,
    mean_pval_o)``; the same six values are printed.

  Raises:
    ValueError: if ``gt`` is empty (there is nothing to average).

  NOTE(review): targets and predictions are min-max normalised and then cast
  with ``astype('int')``, which truncates, so only entries equal to the
  maximum become 1. Confirm this binarisation is intended.
  """
  n_samples = len(gt)
  if n_samples == 0:
    raise ValueError("get_score needs at least one ground-truth map")

  if writer is None:
    # Backward compatible with callers that rely on a global ``writer``.
    writer = globals().get("writer")

  # Running sums in return order: auc_r, dice_r, pval_r, auc_o, dice_o, pval_o.
  totals = [0, 0, 0, 0, 0, 0]

  # Steps are 1-based, matching the indices logged to the writer.
  for step in range(1, n_samples + 1):
    true = gt[step - 1].flatten()
    true_bin = norm(true).astype('int')

    auc_r, dice_r, pval_r = _score_prediction(true, true_bin, result_r[step - 1].flatten())
    auc_o, dice_o, pval_o = _score_prediction(true, true_bin, result_o[step - 1].flatten())

    sample_scores = (auc_r, dice_r, pval_r, auc_o, dice_o, pval_o)
    totals = [total + score for total, score in zip(totals, sample_scores)]

    if writer is not None:
      writer.add_scalars("test_score", {
          'Auc_r': auc_r,
          'Dice_r': dice_r,
          'Pval_r': pval_r,
          'Auc_o': auc_o,
          'Dice_o': dice_o,
          'Pval_o': pval_o,
      }, step)

  means = tuple(total / n_samples for total in totals)
  print(*means)
  return means

main

In [None]:
def _scale_by_max(volume):
    """Divide ``volume`` by its maximum, then replace NaN/inf from that division."""
    return np.nan_to_num(volume / np.max(volume), copy=False)


if __name__=='__main__':
    # init
    args = init_args()

    os.environ["CUDA_DEVICE_ORDER"] = 'PCI_BUS_ID'
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus

    output_dir = os.path.join(args.save_test_dir, 'title')

    if os.path.exists(output_dir):
        print(f'Output dir exists: {output_dir}')
    os.makedirs(output_dir, exist_ok=True)

    writer = SummaryWriter(os.path.join(output_dir, "test_time10_thre0.7_near14_rouge-1_num_logs"))

    # Close the writer even if loading, inference or scoring fails.
    try:
        # load Data
        test_meta = pd.read_csv('/disk1/wyn/workshop/ChatGPT/text2brain-main/Result/test_time10_thre0.7_near14_rouge-1_num10.csv')
        brain_map_dir = '/disk1/wyn/workshop/ChatGPT/text2brain-main/data/brain_maps/neuroquery'
        output = '/disk1/wyn/workshop/ChatGPT/text2brain-main/Result/output/'
        # Create the destination up front so np.save cannot fail after the
        # whole inference loop has already run.
        os.makedirs(output, exist_ok=True)

        # init model
        model = Text2BrainModel(
            out_channels=1,
            fc_channels=args.n_fc_channels,
            decoder_filters=args.n_decoder_channels,
            pretrained_bert_dir=args.pretrained_bert_dir,
            drop_p=args.drop_p)
        # Same availability guard the model uses internally for its inputs.
        if torch.cuda.is_available():
            model.cuda()

        # loading checkpoint
        # NOTE(review): the path is hard-coded; ``args.checkpoint_file`` is not used here.
        # map_location lets a checkpoint saved on GPU load on a CPU-only machine;
        # load_state_dict then copies the tensors onto the model's device.
        state_dict = torch.load(
            '/disk1/wyn/workshop/ChatGPT/text2brain-main/Chat2Brain_checkpoint/aug_loss.pth',
            map_location='cpu')['state_dict']
        model.load_state_dict(state_dict)
        # Evaluation mode is loop-invariant, so set it once.
        model.eval()

        # test
        gt = []
        r = []
        o = []

        for index, row in test_meta.iterrows():
            article_id = row['id']

            text_r = row['Result']
            text_r = text_r.replace('TITLE:', '')
            text_r = text_r.replace('\n', '')
            text_r = text_r.lstrip() # optimized text by impressionGPT

            text_o = row['title'] # original text

            brain_map = nib.load(os.path.join(brain_map_dir, str(article_id) + '.nii.gz'))
            brain_map = brain_map.get_fdata()
            # Crop the loaded volume before comparison with the predictions.
            brain_map = brain_map[3:-3, 3:-4, :-6]
            brain_map = _scale_by_max(brain_map)

            gt.append(brain_map)

            with torch.no_grad():
                predict_r = model([text_r])
                predict_o = model([text_o])

            predict_r = _scale_by_max(np.squeeze(predict_r.cpu().detach().numpy()))
            predict_o = _scale_by_max(np.squeeze(predict_o.cpu().detach().numpy()))

            r.append(predict_r)
            o.append(predict_o)

        gt = np.array(gt)
        o = np.array(o)
        r = np.array(r)

        np.save(os.path.join(output, 'gt.npy'), gt)
        np.save(os.path.join(output, 'o.npy'), o)
        np.save(os.path.join(output, 'r.npy'), r)

        mean_auc_r, mean_dice_r, mean_pval_r, mean_auc_o, mean_dice_o, mean_pval_o = get_score(r, o, gt)
    finally:
        writer.close()

# Visualization

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Index (along the first axis of the loaded array) of the sample to plot.
sample_index = 2
# Only entries strictly greater than this value are plotted.
thresholds = 0.8
# Load the saved array `r.npy` from Google Drive.
# NOTE(review): presumably the predictions written by the test script — confirm.
Res_result = np.load('/content/drive/MyDrive/Colab/CoordinateGPT/result/r.npy')

# Indices, one array per axis, of the selected sample's entries above the threshold.
m, n, k = np.where(Res_result[sample_index]>thresholds)
# 3-D scatter plot of those index triples.
fig, ax = plt.subplots(subplot_kw=dict(projection='3d'))
scatter = ax.scatter(m, n, k)