In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
!pip install fsspec -q
!pip install gcsfs -q

In [None]:
!pip install tensorflow_addons -q
!pip install -U efficientnet -q
!pip install vit-keras -q
!pip install git+https://github.com/qubvel/classification_models.git -q

In [None]:
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
import random
import os
import re
import math
import sys
sys.path.append('/content/drive/MyDrive/TFNFNet')

tqdm.pandas()

import matplotlib.pyplot as plt
import cv2

from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn import metrics

import tensorflow as tf
import tensorflow_addons as tfa
import tensorflow_hub as hub

# policy = tf.keras.mixed_precision.Policy('mixed_bfloat16')
# tf.keras.mixed_precision.set_global_policy(policy)

import efficientnet.tfkeras as efn
from vit_keras import vit
from classification_models.tfkeras import Classifiers

import logging
logging.getLogger("tensorflow").setLevel(logging.WARNING)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 

from nfnet import NFNet

def set_seed(seed = 0):
    '''Seed every random number generator this notebook relies on.

    Seeds Python's `random`, NumPy's legacy global generator and TensorFlow's
    global generator, and exports the value as PYTHONHASHSEED.

    Args:
        seed: integer seed shared by all generators.
    '''
    # Exported for any process that reads it; the string form is required by os.environ.
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    tf.random.set_seed(seed)
    
import warnings
warnings.filterwarnings('ignore')

In [None]:
# Probe for a TPU; when no cluster can be resolved, fall back to the default strategy.
try:
    # No arguments are needed when the TPU address is supplied by the environment.
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print('Running on TPU ', tpu.master())
except ValueError:
    tpu = None

if not tpu:
    # Default distribution strategy (CPU or a single GPU).
    strategy = tf.distribute.get_strategy()
else:
    # The cluster must be connected and initialised before the strategy is built.
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.TPUStrategy(tpu)

print("REPLICAS: ", strategy.num_replicas_in_sync)

In [None]:
# ---- Run configuration -------------------------------------------------------
# Number of epochs passed to model.fit.
EPOCHS = 5
# Number of StratifiedKFold splits.
FOLDS = 5
# Index of the single fold trained by one run of the training cell.
TRAIN_FOLD = 1 # 0 -> 1 -> 2 -> 3 -> 4
# NOTE(review): OLD_DIM is not referenced anywhere in the visible code.
OLD_DIM = 256
# Side length (pixels) of the square images fed to the model.
DIM = 600
# NOTE(review): SEED is not referenced in the visible code (the fold split
# hard-codes random_state=42 and set_seed is never called) - confirm intent.
SEED = 42
IMAGE_SIZE = [DIM, DIM]
AUTO = tf.data.experimental.AUTOTUNE
# Bucket prefix used to build the per-sample .npy paths.
GCS_DS_PATH = 'gs://kds-03e1a10468eed52cf1f7962028d75ca9e9ca67e4b3c30a4f640075f4'
# Learning rate given to Adam; also the base value for the LR_* schedule constants.
LR = 5e-5
# Global batch size: 4 samples per replica.
BATCH_SIZE = strategy.num_replicas_in_sync*4
# The next values (and hub_url / MODELS below) are only read by the
# commented-out model variants, except NFNET_VARIANT which the active CNN() uses.
PRETRAINED_MODELV2 =  'efficientnetv2-l-21k'
PRETRAINED_MODEL = 'seresnext50'
NFNET_VARIANT = 'F3'
VIT_VARIANT = 'vit_b32'
hub_url = f'gs://cloud-tpu-checkpoints/efficientnet/v2/hub/{PRETRAINED_MODELV2}/feature-vector'
# Bucket prefix from which CNN() loads the NFNet checkpoint.
MODEL_GCS_PATH = 'gs://kds-5b3917d3d77c2d5e35aa3dff601deb72749b3b749ff524b335978b15'

# Short name -> EfficientNet constructor from the `efficientnet` package.
MODELS = {
    'B0': efn.EfficientNetB0,
    'B1': efn.EfficientNetB1,
    'B2': efn.EfficientNetB2,
    'B3': efn.EfficientNetB3,
    'B4': efn.EfficientNetB4,
    'B5': efn.EfficientNetB5,
    'B6': efn.EfficientNetB6,
    'B7': efn.EfficientNetB7
}

In [None]:
def _current_npy_path(sample_id):
    """Bucket path of a current-dataset sample; files are sharded by the id's first character."""
    return f"{GCS_DS_PATH}/train/{sample_id[0]}/{sample_id}.npy"


def _old_npy_path(sample_id):
    """Bucket path of a sample from the old dataset, sharded the same way."""
    return f"{GCS_DS_PATH}/old_leaky_data/train_old/{sample_id[0]}/{sample_id}.npy"


labels = pd.read_csv('/content/drive/MyDrive/Kaggle/train_labels.csv')
labels['path'] = labels['id'].progress_apply(_current_npy_path)

old_labels = pd.read_csv('/content/drive/MyDrive/Kaggle/train_labels_old.csv')
old_labels['path'] = old_labels['id'].progress_apply(_old_npy_path)

# Alternative: train on both sets combined.
# labels = labels.append(old_labels).reset_index(drop=True)
# This run uses only the old dataset; the frame loaded above is replaced.
labels = old_labels

In [None]:
def build_decoder(with_labels=True, target_size=(256, 256), ext='npy'):
    """Build a function that reads a file path and returns a resized float32 image.

    Args:
        with_labels: when true, the returned function takes (path, label) and
            passes the label through untouched; otherwise it takes a path only.
        target_size: (height, width) handed to tf.image.resize.
        ext: 'npy', 'png', 'jpg' or 'jpeg'. Any other value raises ValueError
            the first time the returned function is called/traced.

    Returns:
        The decode function matching `with_labels`.
    """
    def _from_npy(raw):
        values = tf.io.decode_raw(raw, tf.float16)
        # Drop the first 64 float16 values - presumably the 128-byte .npy header; TODO confirm.
        values = values[64:]
        planes = tf.reshape(values, [6, 273, 256])
        # Keep planes 0, 2 and 4 and lay them out one under another along the first axis.
        tall = tf.concat([planes[0,:,:], planes[2,:,:], planes[4,:,:]], axis=0)
        # Replicate the single plane into three identical channels.
        rgb = tf.stack([tall, tall, tall], axis=-1)
        rgb = tf.cast(rgb, tf.float32)
        return tf.image.resize(rgb, target_size)

    def _from_encoded(raw):
        if ext == 'png':
            pixels = tf.image.decode_png(raw, channels=3)
        elif ext in ['jpg', 'jpeg']:
            pixels = tf.image.decode_jpeg(raw, channels=3)
        else:
            raise ValueError("Image extension not supported")
        # Scale 8-bit pixel values to [0, 1].
        pixels = tf.cast(pixels, tf.float32) / 255.0
        return tf.image.resize(pixels, target_size)

    def decode(path):
        file_bytes = tf.io.read_file(path)
        if ext == 'npy':
            return _from_npy(file_bytes)
        return _from_encoded(file_bytes)

    def decode_with_labels(path, label):
        return decode(path), label

    if with_labels:
        return decode_with_labels
    return decode


def build_augmenter(with_labels=True):
    """Build a function applying random horizontal and vertical flips to an image.

    Args:
        with_labels: when true, the returned function takes (img, label) and
            returns the label unchanged; otherwise it takes the image only.

    Returns:
        The augmentation function matching `with_labels`.
    """
    def augment(img):
        flipped = tf.image.random_flip_left_right(img)
        return tf.image.random_flip_up_down(flipped)

    def augment_with_labels(img, label):
        return augment(img), label

    if with_labels:
        return augment_with_labels
    return augment


def build_dataset(paths, labels=None, bsize=128, cache=True,
                  decode_fn=None, augment_fn=None,
                  augment=True, repeat=True, shuffle=1024, 
                  cache_dir=""):
    """Assemble a batched, prefetched tf.data pipeline over image file paths.

    Pipeline order: decode -> cache -> augment -> shuffle -> repeat -> batch -> prefetch.
    Augmentation comes after the cache so each pass gets fresh random flips.
    Shuffling comes before `repeat()` so every sample of one pass is emitted
    before the next pass starts; repeating first would let the shuffle buffer
    mix samples across epoch boundaries.

    Args:
        paths: file paths, one per sample.
        labels: optional labels aligned with `paths`; when given, the dataset
            yields (image, label) pairs, otherwise images only.
        bsize: batch size.
        cache: truthy to cache decoded samples (in memory when `cache_dir` is "").
        decode_fn: decoder; defaults to build_decoder(labels is not None).
        augment_fn: augmenter; defaults to build_augmenter(labels is not None).
        augment: whether to apply `augment_fn`.
        repeat: whether to repeat the dataset indefinitely.
        shuffle: shuffle buffer size; falsy disables shuffling.
        cache_dir: passed to Dataset.cache; created first when non-empty.

    Returns:
        The tf.data.Dataset.
    """
    # Same truthiness test as the .cache() call below, so the directory is
    # created whenever caching to disk is actually going to happen.
    if cache and cache_dir != "":
        os.makedirs(cache_dir, exist_ok=True)
    
    has_labels = labels is not None
    if decode_fn is None:
        decode_fn = build_decoder(has_labels)
    
    if augment_fn is None:
        augment_fn = build_augmenter(has_labels)
    
    autotune = tf.data.experimental.AUTOTUNE
    slices = (paths, labels) if has_labels else paths
    
    dset = tf.data.Dataset.from_tensor_slices(slices)
    dset = dset.map(decode_fn, num_parallel_calls=autotune)
    if cache:
        dset = dset.cache(cache_dir)
    if augment:
        dset = dset.map(augment_fn, num_parallel_calls=autotune)
    if shuffle:
        dset = dset.shuffle(shuffle)
    if repeat:
        dset = dset.repeat()
    dset = dset.batch(bsize).prefetch(autotune)
    
    return dset

In [None]:
def _random_affine(image):
    """Apply one random rotation/shear/zoom/shift to a single image.

    Args:
        image: one image of size [dim, dim, 3] (not a batch), where dim is
            IMAGE_SIZE[0].

    Returns:
        The transformed image, reshaped to [dim, dim, 3].
    """
    DIM = IMAGE_SIZE[0]
    XDIM = DIM%2 #fix for size 331
    
    # No per-op seed here: inside a traced function (e.g. Dataset.map) random ops
    # that share the same seed return the same values, which would make all six
    # parameters copies of one draw. Use tf.random.set_seed for reproducibility.
    rot = 1.* tf.random.normal([1], dtype='float32')
    shr = tf.random.normal([1], dtype='float32') 
    h_zoom = 1.0 + tf.random.normal([1], dtype='float32')/10.
    w_zoom = 1.0 + tf.random.normal([1], dtype='float32')/10.
    h_shift = 16. * tf.random.normal([1], dtype='float32') 
    w_shift = 16. * tf.random.normal([1], dtype='float32') 
  
    # GET TRANSFORMATION MATRIX
    m = get_mat(rot,shr,h_zoom,w_zoom,h_shift,w_shift) 

    # LIST DESTINATION PIXEL INDICES
    x = tf.repeat( tf.range(DIM//2,-DIM//2,-1), DIM )
    y = tf.tile( tf.range(-DIM//2,DIM//2),[DIM] )
    z = tf.ones([DIM*DIM],dtype='int32')
    idx = tf.stack( [x,y,z] )
    
    # ROTATE DESTINATION PIXELS ONTO ORIGIN PIXELS
    idx2 = tf.keras.backend.dot(m,tf.cast(idx,dtype='float32'))
    idx2 = tf.keras.backend.cast(idx2,dtype='int32')
    # Clip so that the gather indices computed below stay inside the image.
    idx2 = tf.keras.backend.clip(idx2,-DIM//2+XDIM+1,DIM//2)
    
    # FIND ORIGIN PIXEL VALUES           
    idx3 = tf.stack( [DIM//2-idx2[0,], DIM//2-1+idx2[1,]] )
    d = tf.gather_nd(image,tf.transpose(idx3))
        
    return tf.reshape(d,[DIM,DIM,3])

def transform(image,label):
    """Randomly rotate, shear, zoom and shift one labelled image.

    Args:
        image: one image of size [dim, dim, 3], not a batch of [b, dim, dim, 3].
        label: returned unchanged.

    Returns:
        (transformed image, label).
    """
    return _random_affine(image),label

def transform_test(image):
    """Randomly rotate, shear, zoom and shift one unlabelled image.

    Args:
        image: one image of size [dim, dim, 3], not a batch of [b, dim, dim, 3].

    Returns:
        The transformed image.
    """
    return _random_affine(image)

In [None]:
def get_mat(rotation, shear, height_zoom, width_zoom, height_shift, width_shift):
    """Compose a 3x3 matrix that transforms homogeneous pixel indices.

    Every argument is a float32 tensor of shape [1]. `rotation` and `shear`
    are in degrees. The result is (rotation . shear) . (zoom . shift).

    Returns:
        A [3, 3] float32 tensor.
    """
    def as_3x3(*entries):
        # Nine length-1 tensors, given row by row, become one 3x3 matrix.
        return tf.reshape(tf.concat(list(entries), axis=0), [3, 3])

    matmul = tf.keras.backend.dot
    one = tf.constant([1],dtype='float32')
    zero = tf.constant([0],dtype='float32')

    # Degrees to radians.
    rotation_rad = math.pi * rotation / 180.
    shear_rad = math.pi * shear / 180.

    cos_r = tf.math.cos(rotation_rad)
    sin_r = tf.math.sin(rotation_rad)
    rotation_matrix = as_3x3(cos_r, sin_r, zero,
                             -sin_r, cos_r, zero,
                             zero, zero, one)

    cos_s = tf.math.cos(shear_rad)
    sin_s = tf.math.sin(shear_rad)
    shear_matrix = as_3x3(one, sin_s, zero,
                          zero, cos_s, zero,
                          zero, zero, one)

    # Zoom divides, so a zoom factor above 1 shrinks the index coordinates.
    zoom_matrix = as_3x3(one/height_zoom, zero, zero,
                         zero, one/width_zoom, zero,
                         zero, zero, one)

    shift_matrix = as_3x3(one, zero, height_shift,
                          zero, one, width_shift,
                          zero, zero, one)

    rotate_then_shear = matmul(rotation_matrix, shear_matrix)
    zoom_then_shift = matmul(zoom_matrix, shift_matrix)
    return matmul(rotate_then_shear, zoom_then_shift)

In [None]:
# Number of samples per batch that the batch-level augmentations iterate over.
AUG_BATCH = BATCH_SIZE

def cutmix(image, label, PROBABILITY = 1.0):
    """Apply CutMix to a batch.

    For each sample, with probability PROBABILITY, a square patch taken from a
    randomly chosen sample of the same batch is pasted at a random location.
    The label becomes a blend weighted by the fraction of the image the patch
    really covers.

    Args:
        image: a batch of images of size [AUG_BATCH, dim, dim, 3], where dim is
            IMAGE_SIZE[0] (not a single image).
        label: the batch's labels; cast to float32.
        PROBABILITY: chance that a given sample receives a patch.

    Returns:
        (images of shape [AUG_BATCH, dim, dim, 3], labels of shape [AUG_BATCH, 1]).
    """
    DIM = IMAGE_SIZE[0]
    label = tf.cast(label, 'float32')
    
    imgs = []; labs = []
    for j in range(AUG_BATCH):
        # DO CUTMIX WITH PROBABILITY DEFINED ABOVE
        P = tf.cast( tf.random.uniform([],0,1)<=PROBABILITY, tf.int32)
        # CHOOSE RANDOM IMAGE TO CUTMIX WITH
        k = tf.cast( tf.random.uniform([],0,AUG_BATCH),tf.int32)
        # CHOOSE RANDOM LOCATION
        x = tf.cast( tf.random.uniform([],0,DIM),tf.int32)
        y = tf.cast( tf.random.uniform([],0,DIM),tf.int32)
        b = tf.random.uniform([],0,1) # this is beta dist with alpha=1.0
        # Multiplying by P gives a zero-width (empty) patch when cutmix is skipped.
        WIDTH = tf.cast( DIM * tf.math.sqrt(1-b),tf.int32) * P
        ya = tf.math.maximum(0,y-WIDTH//2)
        yb = tf.math.minimum(DIM,y+WIDTH//2)
        xa = tf.math.maximum(0,x-WIDTH//2)
        xb = tf.math.minimum(DIM,x+WIDTH//2)
        # MAKE CUTMIX IMAGE
        one = image[j,ya:yb,0:xa,:]
        two = image[k,ya:yb,xa:xb,:]
        three = image[j,ya:yb,xb:DIM,:]
        middle = tf.concat([one,two,three],axis=1)
        img = tf.concat([image[j,0:ya,:,:],middle,image[j,yb:DIM,:,:]],axis=0)
        imgs.append(img)
        # MAKE CUTMIX LABEL
        # Weight by the area actually pasted: the patch is clipped at the image
        # border (and loses a pixel when WIDTH is odd), so WIDTH*WIDTH would
        # overstate the share of the second sample.
        a = tf.cast((yb-ya)*(xb-xa),tf.float32) / float(DIM*DIM)

        lab1 = label[j,]
        lab2 = label[k,]
        labs.append((1-a)*lab1 + a*lab2)
            
    # RESHAPE HACK SO TPU COMPILER KNOWS SHAPE OF OUTPUT TENSOR (maybe use Python typing instead?)
    image2 = tf.reshape(tf.stack(imgs),(AUG_BATCH, DIM,DIM,3))
    label2 = tf.reshape(tf.stack(labs),(AUG_BATCH, 1))
    return image2,label2

def mixup(image, label, PROBABILITY = 1.0):
    """Apply MixUp with a fixed 50/50 blend to a batch.

    For each sample, with probability PROBABILITY, the image and label are
    averaged with those of a randomly chosen sample of the same batch;
    otherwise the sample is returned unchanged. With the default
    PROBABILITY of 1.0 every sample is blended.

    Args:
        image: a batch of images of size [AUG_BATCH, dim, dim, 3], where dim is
            IMAGE_SIZE[0] (not a single image).
        label: the batch's labels; cast to float32.
        PROBABILITY: chance that a given sample is blended.

    Returns:
        (images of shape [AUG_BATCH, dim, dim, 3], labels of shape [AUG_BATCH, 1]).
    """
    DIM = IMAGE_SIZE[0]
    label = tf.cast(label, 'float32')
    
    imgs = []; labs = []
    for j in range(AUG_BATCH):
        # DO MIXUP WITH PROBABILITY DEFINED ABOVE
        P = tf.cast( tf.random.uniform([],0,1)<=PROBABILITY, tf.float32)
        # CHOOSE RANDOM
        k = tf.cast( tf.random.uniform([],0,AUG_BATCH),tf.int32)
        # Fixed blend weight, zeroed when this sample is not selected so that
        # PROBABILITY is honoured (it was previously computed but ignored).
        a = 0.5 * P
        # MAKE MIXUP IMAGE
        img1 = image[j,]
        img2 = image[k,]
        imgs.append((1-a)*img1 + a*img2)
        # MAKE MIXUP LABEL
        lab1 = label[j,]
        lab2 = label[k,]
        labs.append((1-a)*lab1 + a*lab2)

    # RESHAPE HACK SO TPU COMPILER KNOWS SHAPE OF OUTPUT TENSOR (maybe use Python typing instead?)
    image2 = tf.reshape(tf.stack(imgs),(AUG_BATCH,DIM,DIM,3))
    label2 = tf.reshape(tf.stack(labs),(AUG_BATCH,1))
    return image2,label2

def transform_cutmix_mixup(image,label):
    """Per sample, pick either the cutmix or the mixup version of a batch.

    Both augmentations are computed for the whole batch (each called with
    probability 0.666); every sample then takes its cutmix result with
    probability 0.5 and its mixup result otherwise.

    Args:
        image: a batch of images of size [AUG_BATCH, dim, dim, 3], where dim is
            IMAGE_SIZE[0].
        label: the batch's labels.

    Returns:
        (images of shape [AUG_BATCH, dim, dim, 3], labels rounded with
        tf.math.round and returned as a flat int64 vector).
    """
    side = IMAGE_SIZE[0]
    switch_prob = 0.5
    cutmix_prob = 0.666
    mixup_prob = 0.666
    label = tf.cast(label, 'float32')

    cut_images, cut_labels = cutmix(image, label, cutmix_prob)
    mix_images, mix_labels = mixup(image, label, mixup_prob)

    picked_images = []
    picked_labels = []
    for idx in range(AUG_BATCH):
        # 1.0 selects the cutmix result for this sample, 0.0 the mixup result.
        take_cutmix = tf.cast( tf.random.uniform([],0,1)<=switch_prob, tf.float32)
        picked_images.append(take_cutmix*cut_images[idx,]+(1-take_cutmix)*mix_images[idx,])
        picked_labels.append(take_cutmix*cut_labels[idx,]+(1-take_cutmix)*mix_labels[idx,])

    # Explicit reshapes so the TPU compiler knows the static output shapes.
    out_images = tf.reshape(tf.stack(picked_images),(AUG_BATCH,side,side,3))
    out_labels = tf.reshape(tf.stack(picked_labels),(AUG_BATCH,1))
    hard_labels = tf.cast(tf.math.round(out_labels), 'int64')
    return out_images,tf.reshape(hard_labels, (-1,))

In [None]:
# One of:

##################### Qubvel's Effnets #####################

# def CNN(MODEL):
#     model = tf.keras.models.Sequential()
#     model.add(MODEL(weights='imagenet', include_top=False, input_shape=(*(IMAGE_SIZE), 3)))
#     model.add(tf.keras.layers.GlobalAveragePooling2D())
#     model.add(tf.keras.layers.Dropout(0.5))
#     model.add(tf.keras.layers.Dense(1, activation='sigmoid'))

#     model.compile(
#         optimizer=tf.keras.optimizers.Adam(learning_rate=LR),
#         loss=tf.keras.losses.BinaryCrossentropy(),
#         metrics=[tf.keras.metrics.AUC(name='AUC')]
#     )
    
#     return model

##################### Qubvel's Classifiers #####################

# def CNN():
#     model = tf.keras.models.Sequential()
#     model.add(Classifiers.get(PRETRAINED_MODEL)[0](weights='imagenet', include_top=False, input_shape=(*(IMAGE_SIZE), 3)))
#     model.add(tf.keras.layers.GlobalAveragePooling2D())
#     model.add(tf.keras.layers.Dropout(0.5))
#     model.add(tf.keras.layers.Dense(1, activation='sigmoid'))

#     model.compile(
#         optimizer=tfa.optimizers.SWA(tf.keras.optimizers.Adam(learning_rate=LR)),
#         loss=tf.keras.losses.BinaryCrossentropy(),
#         metrics=[tf.keras.metrics.AUC(name='AUC')]
#     )
    
#     return model

##################### Qubvel's EffnetL2 with xhlulu's pretrained weights #####################

# def CNN():
#     model = tf.keras.models.Sequential()
#     model.add(efn.EfficientNetL2(weights='/content/drive/MyDrive/EfficientNetL2_pretrained/efficientnet-l2_noisy-student_notop.h5', drop_connect_rate=0, include_top=False, input_shape=(*(IMAGE_SIZE), 3)))
#     model.add(tf.keras.layers.GlobalAveragePooling2D())
#     model.add(tf.keras.layers.Dropout(0.5))
#     model.add(tf.keras.layers.Dense(1, activation='sigmoid'))

#     model.compile(
#         optimizer=tfa.optimizers.SWA(tf.keras.optimizers.Adam(learning_rate=LR), start_averaging=2),
#         loss=tf.keras.losses.BinaryCrossentropy(),
#         metrics=[tf.keras.metrics.AUC(name='AUC')]
#     )
    
#     return model

##################### faustomorales's ViT #####################

# def CNN():
#   model = vit.vit_b32(
#       image_size=(DIM, DIM),
#       activation='sigmoid',
#       pretrained=True,
#       include_top=True,
#       pretrained_top=False,
#       classes=1
#   )

#   model.compile(
#       optimizer=tf.keras.optimizers.Adam(learning_rate=LR),
#       loss=tf.keras.losses.BinaryCrossentropy(),
#       metrics=[tf.keras.metrics.AUC(name='AUC')]
#   )
  
#   return model

##################### Google's EffnetV2 #####################

# def CNN():
#     model = tf.keras.models.Sequential()
#     model.add(tf.keras.layers.InputLayer(input_shape=(DIM, DIM, 3)))
#     model.add(hub.KerasLayer(hub_url, trainable=True))
#     model.add(tf.keras.layers.Dropout(0.5))
#     model.add(tf.keras.layers.Dense(1, activation='sigmoid'))

#     model.compile(
#         optimizer=tf.keras.optimizers.Adam(learning_rate=LR),
#         loss=tf.keras.losses.BinaryCrossentropy(),
#         metrics=[tf.keras.metrics.AUC(name='AUC')]
#     )
    
#     return model

##################### Google's NFNet with hoangthang1607's pretrained weights #####################

def CNN():
    """Build and compile the NFNet-based binary classifier.

    An NFNet of variant NFNET_VARIANT is created with a 1000-class head and
    its weights are loaded from MODEL_GCS_PATH. The network's 'pool' output
    goes through dropout (0.2) into a single sigmoid unit. The model is
    compiled with Adam at learning rate LR, binary cross-entropy loss and an
    AUC metric named 'AUC'.

    Returns:
        The compiled tf.keras Model taking (DIM, DIM, 3) inputs.
    """
    input_shape = (DIM, DIM, 3)
    image_input = tf.keras.layers.Input(shape=input_shape)

    backbone = NFNet(num_classes=1000, variant=NFNET_VARIANT)
    backbone.load_weights(f'{MODEL_GCS_PATH}/NFNets_weights/{NFNET_VARIANT}_NFNet/{NFNET_VARIANT}_NFNet')

    # Use the backbone's 'pool' output rather than its class predictions.
    pooled = backbone(image_input)['pool']
    pooled = tf.keras.layers.Dropout(0.2)(pooled)
    probability = tf.keras.layers.Dense(1, activation='sigmoid')(pooled)

    model = tf.keras.models.Model(inputs=image_input, outputs=probability)
    model.build((None,) + input_shape)

    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=LR),
        loss=tf.keras.losses.BinaryCrossentropy(),
        metrics=[tf.keras.metrics.AUC(name='AUC')],
    )
    return model

In [None]:
# Ramp-up / sustain / exponential-decay schedule parameters, all derived from LR.
LR_START = LR
LR_MAX = LR*4
LR_MIN = LR/10
LR_RAMPUP_EPOCHS = 5
LR_SUSTAIN_EPOCHS = 0
LR_EXP_DECAY = .8

def CustomSchedule(epoch):
    """Learning rate for `epoch`: linear ramp-up, optional plateau, then decay.

    Args:
        epoch: epoch index.

    Returns:
        LR_START rising linearly to LR_MAX over LR_RAMPUP_EPOCHS epochs, held
        at LR_MAX for LR_SUSTAIN_EPOCHS epochs, then decaying towards LR_MIN
        by a factor of LR_EXP_DECAY per epoch.
    """
    if epoch < LR_RAMPUP_EPOCHS:
        return (LR_MAX - LR_START) / LR_RAMPUP_EPOCHS * epoch + LR_START

    decay_start = LR_RAMPUP_EPOCHS + LR_SUSTAIN_EPOCHS
    if epoch < decay_start:
        return LR_MAX

    return (LR_MAX - LR_MIN) * LR_EXP_DECAY**(epoch - LR_RAMPUP_EPOCHS - LR_SUSTAIN_EPOCHS) + LR_MIN

# NOTE(review): DECAY_RATE was referenced but never defined anywhere in the
# notebook, so the schedule raised NameError on a fresh kernel. 0.1 is a
# placeholder - confirm the intended value.
DECAY_RATE = 0.1

def ExponentialSchedule(epoch, lr):
    """Scale the given learning rate by exp(-DECAY_RATE * epoch).

    Args:
        epoch: epoch index.
        lr: learning rate to scale (presumably the optimizer's current rate
            when used with LearningRateScheduler - verify against the caller).

    Returns:
        The scaled learning rate as a Python float.
    """
    rate = DECAY_RATE
    # math.exp yields a plain float directly; no tensor round-trip is needed.
    return lr*math.exp(-rate*(epoch))


# Callback applying CustomSchedule at the start of each epoch.
CustomCallback = tf.keras.callbacks.LearningRateScheduler(
    CustomSchedule,
    verbose = True,
)
# Stop when validation AUC has not improved for 10 epochs and restore the best weights.
EarlyStopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_AUC',
    patience=10,
    mode='max',
    restore_best_weights=True,
)
# Lower the learning rate when validation AUC stalls, never below 1e-8.
ReduceOnPlateau = tf.keras.callbacks.ReduceLROnPlateau(
    monitor='val_AUC',
    mode='max',
    min_lr=1e-8,
)

In [None]:
# Arrays of file paths and their targets, aligned row for row with `labels`.
paths, targets = labels['path'].values, labels['target'].values

In [None]:
skf = StratifiedKFold(n_splits=FOLDS, shuffle=True, random_state=42)
# Out-of-fold predictions; only the rows of the fold trained below are filled in.
oof_preds = pd.DataFrame({'True': targets, 'Preds': np.zeros(len(targets))})

for fold, (train_idx, valid_idx) in enumerate(skf.split(paths, targets)):
  # A single run trains only the fold selected by TRAIN_FOLD.
  if fold != TRAIN_FOLD:
      continue

  print('Fold:', fold)

  X_filenames = paths[train_idx]
  Xvalid_filenames = paths[valid_idx]
  train_targets = targets[train_idx]
  valid_targets = targets[valid_idx]

  # Whole batches per pass over each split.
  TRAIN_STEPS = X_filenames.shape[0]//BATCH_SIZE
  VALID_STEPS = Xvalid_filenames.shape[0]//BATCH_SIZE

  decoder = build_decoder(with_labels=True, target_size=IMAGE_SIZE, ext='npy')

  # Training pipeline keeps the defaults: cache, flips, repeat and shuffle.
  train_dataset = build_dataset(
      X_filenames,
      train_targets,
      bsize=BATCH_SIZE,
      decode_fn=decoder,
  )
  # Validation pipeline: fixed order, one pass, no augmentation.
  valid_dataset = build_dataset(
      Xvalid_filenames,
      valid_targets,
      bsize=BATCH_SIZE,
      decode_fn=decoder,
      shuffle=False,
      repeat=False,
      augment=False,
  )
  # Batch-level mixup is applied on top of the per-image flips.
  train_dataset = train_dataset.map(mixup, num_parallel_calls=AUTO)

  # The model must be created inside the distribution strategy's scope.
  with strategy.scope():
      model = CNN()

  model.fit(
      train_dataset,
      epochs=EPOCHS,
      steps_per_epoch=TRAIN_STEPS,
      validation_data=valid_dataset,
      validation_steps=VALID_STEPS,
      callbacks=[ReduceOnPlateau, EarlyStopping],
  )

  y_pred = model.predict(valid_dataset).reshape(-1,)
  print('AUC:', metrics.roc_auc_score(valid_targets, y_pred))

  # model.save(f'/content/drive/MyDrive/Kaggle/weights/Pretrained_weights/EfficientNet-L2/efficientnet_l2_mixup_{fold}_SWA.h5')
  # model.save(f'/content/drive/MyDrive/Kaggle/weights/Pretrained_weights/VIT-B32/{VIT_VARIANT}_{fold}.h5')

  # Write the SavedModel through the local host's filesystem.
  save_locally = tf.saved_model.SaveOptions(experimental_io_device='/job:localhost')
  model.save(f'/content/drive/MyDrive/Kaggle/weights/Pretrained_weights/NFNet-{NFNET_VARIANT}/{fold}', options=save_locally)

  oof_preds.loc[valid_idx, 'Preds'] = y_pred

  del model