From b145e64417e05ffcbee3f2888be384174abd38f5 Mon Sep 17 00:00:00 2001
From: Jonathan Calderon Chavez <chavezcalderon@google.com>
Date: Thu, 17 Oct 2024 20:00:46 +0000
Subject: [PATCH] Play around with colab image

---
 Dockerfile.tmpl            | 280 +++++++++----------------------------
 clean-layer.sh             |   2 -
 test                       |  23 ++-
 tests/common.py            |   3 +-
 tests/test_jax.py          |   2 +-
 tests/test_user_secrets.py |  25 ++++
 6 files changed, 113 insertions(+), 222 deletions(-)

diff --git a/Dockerfile.tmpl b/Dockerfile.tmpl
index 11b05ebd..ce9be588 100644
--- a/Dockerfile.tmpl
+++ b/Dockerfile.tmpl
@@ -1,53 +1,4 @@
-ARG BASE_IMAGE_REPO \
-    BASE_IMAGE_TAG \
-    CPU_BASE_IMAGE_NAME \
-    GPU_BASE_IMAGE_NAME \
-    LIGHTGBM_VERSION \
-    TORCH_VERSION \
-    TORCHAUDIO_VERSION \
-    TORCHVISION_VERSION \
-    JAX_VERSION
-
-{{ if eq .Accelerator "gpu" }}
-FROM gcr.io/kaggle-images/python-lightgbm-whl:${GPU_BASE_IMAGE_NAME}-${BASE_IMAGE_TAG}-${LIGHTGBM_VERSION} AS lightgbm_whl
-FROM gcr.io/kaggle-images/python-torch-whl:${GPU_BASE_IMAGE_NAME}-${BASE_IMAGE_TAG}-${TORCH_VERSION} AS torch_whl
-FROM gcr.io/kaggle-images/python-jaxlib-whl:${GPU_BASE_IMAGE_NAME}-${BASE_IMAGE_TAG}-${JAX_VERSION} AS jaxlib_whl
-FROM ${BASE_IMAGE_REPO}/${GPU_BASE_IMAGE_NAME}:${BASE_IMAGE_TAG}
-{{ else }}
-FROM ${BASE_IMAGE_REPO}/${CPU_BASE_IMAGE_NAME}:${BASE_IMAGE_TAG}
-{{ end }}
-
-# Ensures shared libraries installed with conda can be found by the dynamic link loader.
-ENV LIBRARY_PATH="$LIBRARY_PATH:/opt/conda/lib" \
-    LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/opt/conda/lib"
-
-{{ if eq .Accelerator "gpu" }}
-ARG CUDA_MAJOR_VERSION \
-    CUDA_MINOR_VERSION
-ENV CUDA_MAJOR_VERSION=${CUDA_MAJOR_VERSION} \
-    CUDA_MINOR_VERSION=${CUDA_MINOR_VERSION}
-# Make sure we are on the right version of CUDA
-RUN update-alternatives --set cuda /usr/local/cuda-$CUDA_MAJOR_VERSION.$CUDA_MINOR_VERSION
-# NVIDIA binaries from the host are mounted to /opt/bin.
-ENV PATH=/opt/bin:${PATH} \
-    # Add CUDA stubs to LD_LIBRARY_PATH to support building the GPU image on a CPU machine.
-    LD_LIBRARY_PATH_NO_STUBS="$LD_LIBRARY_PATH" \
-    LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/usr/local/cuda/lib64/stubs"
-RUN ln -s /usr/local/cuda/lib64/stubs/libcuda.so /usr/local/cuda/lib64/stubs/libcuda.so.1
-{{ end }}
-
-# Keep these variables in sync if base image is updated.
-ENV TENSORFLOW_VERSION=2.16.1 \
-    # See https://github.com/tensorflow/io#tensorflow-version-compatibility
-    TENSORFLOW_IO_VERSION=0.37.0
-
-# We need to redefine the ARG here to get the ARG value defined above the FROM instruction.
-# See: https://docs.docker.com/engine/reference/builder/#understand-how-arg-and-from-interact
-ARG LIGHTGBM_VERSION \
-    TORCH_VERSION \
-    TORCHAUDIO_VERSION \
-    TORCHVISION_VERSION \
-    JAX_VERSION
+FROM us-docker.pkg.dev/colab-images/public/runtime
 
 # Disable pesky logs like: KMP_AFFINITY: pid 6121 tid 6121 thread 0 bound to OS proc set 0
 # See: https://stackoverflow.com/questions/57385766/disable-tensorflow-log-information
@@ -78,99 +29,28 @@ RUN sed -i "s/httpredir.debian.org/debian.uchicago.edu/" /etc/apt/sources.list &
     apt-get install -y graphviz && pip install graphviz && \
     /tmp/clean-layer.sh
 
-# b/128333086: Set PROJ_DATA to points to the proj4 cartographic library.
-ENV PROJ_DATA=/opt/conda/share/proj
-
-# Install micromamba, setup channels, and replace conda with micromamba
-ENV MAMBA_ROOT_PREFIX=/opt/conda
-RUN curl -L "https://micro.mamba.pm/install.sh" -o /tmp/micromamba-install.sh \
-    && bash /tmp/micromamba-install.sh \
-    && rm /tmp/micromamba-install.sh \
-    && mv ~/.local/bin/micromamba /usr/bin/micromamba \
-    && (!(which conda) || cp /usr/bin/micromamba $(which conda)) \
-    && micromamba config append channels nvidia \
-    && micromamba config append channels rapidsai \
-    && micromamba config append channels conda-forge \
-    && micromamba config set channel_priority flexible \
-    && python -m nb_conda_kernels.install --disable
+# # b/128333086: Set PROJ_DATA to points to the proj4 cartographic library.
+# ENV PROJ_DATA=/opt/conda/share/proj
+
+# # Install micromamba, setup channels, and replace conda with micromamba
+# ENV MAMBA_ROOT_PREFIX=/opt/conda
+# RUN curl -L "https://micro.mamba.pm/install.sh" -o /tmp/micromamba-install.sh \
+#     && bash /tmp/micromamba-install.sh \
+#     && rm /tmp/micromamba-install.sh \
+#     && mv ~/.local/bin/micromamba /usr/bin/micromamba \
+#     && (!(which conda) || cp /usr/bin/micromamba $(which conda)) \
+#     && micromamba config append channels nvidia \
+#     && micromamba config append channels rapidsai \
+#     && micromamba config append channels conda-forge \
+#     && micromamba config set channel_priority flexible \
+#     && python -m nb_conda_kernels.install --disable
 
 # Install conda packages not available on pip.
 # When using pip in a conda environment, conda commands should be ran first and then
 # the remaining pip commands: https://www.anaconda.com/using-pip-in-a-conda-environment/
-RUN micromamba install -y mkl cartopy imagemagick pyproj "shapely<2" && \
-    rm -rf /opt/conda/lib/python3.10/site-packages/pyproj/proj_dir/ && \
-    /tmp/clean-layer.sh
-
-# Install spacy
-# b/232247930: uninstall pyarrow to avoid double installation with the GPU specific version.
-# b/341938540: unistall grpc-cpp to allow >=v24.4 cudf and cuml to be installed.
-{{ if eq .Accelerator "gpu" }}
-RUN pip uninstall -y pyarrow && \
-    micromamba install -vvvy spacy "cudf>=24.4" "cuml>=24.4" cupy cuda-version=$CUDA_MAJOR_VERSION.$CUDA_MINOR_VERSION && \
-    /tmp/clean-layer.sh
-{{ else }}
-RUN pip install spacy && \
-    /tmp/clean-layer.sh
-{{ end}}
-
-# Install PyTorch
-# b/356397043: magma-cuda121 is the latest version
-{{ if eq .Accelerator "gpu" }}
-COPY --from=torch_whl /tmp/whl/*.whl /tmp/torch/
-# b/356397043: We are currently using cuda 12.3,
-# but magma-cuda121 is the latest compatible version 
-RUN micromamba install -y -c pytorch magma-cuda121 && \
-    pip install /tmp/torch/*.whl && \
-    sudo apt -y install libsox-dev && \
-    rm -rf /tmp/torch && \
-    /tmp/clean-layer.sh
-{{ else }}
-RUN pip install \
-        torch==$TORCH_VERSION+cpu \
-        torchvision==$TORCHVISION_VERSION+cpu \
-        torchaudio==$TORCHAUDIO_VERSION+cpu \
-        --index-url https://download.pytorch.org/whl/cpu && \
-    /tmp/clean-layer.sh
-{{ end }}
-
-# Install LightGBM
-{{ if eq .Accelerator "gpu" }}
-COPY --from=lightgbm_whl /tmp/whl/*.whl /tmp/lightgbm/
-# Install OpenCL (required by LightGBM GPU version)
-RUN apt-get install -y ocl-icd-libopencl1 clinfo && \
-    mkdir -p /etc/OpenCL/vendors && \
-    echo "libnvidia-opencl.so.1" > /etc/OpenCL/vendors/nvidia.icd && \
-    pip install /tmp/lightgbm/*.whl && \
-    rm -rf /tmp/lightgbm && \
-    /tmp/clean-layer.sh
-{{ else }}
-RUN pip install lightgbm==$LIGHTGBM_VERSION && \
-    /tmp/clean-layer.sh
-{{ end }}
-
-# Install JAX
-{{ if eq .Accelerator "gpu" }}
-COPY --from=jaxlib_whl /tmp/whl/*.whl /tmp/jax/
-# b/319722433#comment9: Use pip wheels once versions matches our CUDA version.
-RUN pip install /tmp/jax/*.whl jax==$JAX_VERSION && \
-    /tmp/clean-layer.sh
-{{ else }}
-RUN pip install jax[cpu] && \
-    /tmp/clean-layer.sh
-{{ end }}
-
-
-# Install GPU specific packages
-{{ if eq .Accelerator "gpu" }}
-# Install GPU-only packages
-# No specific package for nnabla-ext-cuda 12.x minor versions.
-RUN export PATH=/usr/local/cuda/bin:$PATH && \
-    export CUDA_ROOT=/usr/local/cuda && \
-    pip install pycuda \
-        pynvrtc \
-        pynvml && \
-    /tmp/clean-layer.sh
-{{ end }}
+# RUN micromamba install -y mkl cartopy imagemagick pyproj "shapely<2" && \
+#     rm -rf /opt/conda/lib/python3.10/site-packages/pyproj/proj_dir/ && \
+#     /tmp/clean-layer.sh
 
 # b/308525631: Pin Matplotlib until seaborn can be upgraded
 # to >0.13.0 (now it's stuck by a package conflict with ydata-profiling 4.5.1).
@@ -195,28 +75,31 @@ RUN apt-get update && \
 
 RUN pip install -f http://h2o-release.s3.amazonaws.com/h2o/latest_stable_Py.html h2o && /tmp/clean-layer.sh
 
+
+# Keep these variables in sync if base image is updated.
+ENV TENSORFLOW_VERSION=2.17.0 \
+    # See https://github.com/tensorflow/io#tensorflow-version-compatibility
+    TENSORFLOW_IO_VERSION=0.37.1
+
 RUN pip install \
         "tensorflow==${TENSORFLOW_VERSION}" \
         "tensorflow-io==${TENSORFLOW_IO_VERSION}" \
+        "tensorflow_hub>=0.16.0" \
         tensorflow-probability \
         tensorflow_decision_forests \
         tensorflow-text \
-        "tensorflow_hub>=0.16.0" \
         tf-keras \
-        "keras>3" \
         keras-cv \
         keras-nlp && \
     /tmp/clean-layer.sh
 
 ADD patches/keras_internal.py \
     patches/keras_internal_test.py \
-    /opt/conda/lib/python3.10/site-packages/tensorflow_decision_forests/keras/
+    /usr/local/lib/python3.10/dist-packages/tensorflow_decision_forests/keras/
 
 # b/350573866: xgboost v2.1.0 breaks learntools
 RUN apt-get install -y libfreetype6-dev && \
     apt-get install -y libglib2.0-0 libxext6 libsm6 libxrender1 libfontconfig1 --fix-missing && \
-    rm -rf /opt/conda/lib/python3.10/site-packages/numpy* && \
-    pip install "numpy==1.26.4" && \
     pip install gensim \
         textblob \
         wordcloud \
@@ -225,10 +108,7 @@ RUN apt-get install -y libfreetype6-dev && \
         hep_ml && \
     # NLTK Project datasets
     mkdir -p /usr/share/nltk_data && \
-    # NLTK Downloader no longer continues smoothly after an error, so we explicitly list
-    # the corpuses that work
-    # "yes | ..." answers yes to the retry prompt in case of an error. See b/133762095.
-    yes | python -m nltk.downloader -d /usr/share/nltk_data abc alpino averaged_perceptron_tagger \
+    python -m nltk.downloader -d /usr/share/nltk_data abc alpino averaged_perceptron_tagger \
     basque_grammars biocreative_ppi bllip_wsj_no_aux \
     book_grammars brown brown_tei cess_cat cess_esp chat80 city_database cmudict \
     comtrans conll2000 conll2002 conll2007 crubadan dependency_treebank \
@@ -377,10 +257,9 @@ RUN pip install annoy \
         mlcrate && \
     /tmp/clean-layer.sh
 
-# b/273059949: The pre-installed nbconvert is slow on html conversions and has to be force-uninstalled.
-# b/274619697: learntools also requires a specific nbconvert right now
-RUN rm -rf /opt/conda/lib/python3.10/site-packages/{nbconvert,nbclient,mistune,platformdirs}*
-
+# # b/273059949: The pre-installed nbconvert is slow on html conversions and has to be force-uninstalled.
+# # b/274619697: learntools also requires a specific nbconvert right now
+# RUN rm -rf /opt/conda/lib/python3.10/site-packages/{nbconvert,nbclient,mistune,platformdirs}*
 RUN pip install bleach \
         certifi \
         cycler \
@@ -446,8 +325,8 @@ RUN python -m spacy download en_core_web_sm && python -m spacy download en_core_
     #
     ###########
 
-RUN rm /opt/conda/lib/python3.10/site-packages/google*/direct_url.json && \
-    rm /opt/conda/lib/python3.10/site-packages/google*/REQUESTED
+# RUN rm /opt/conda/lib/python3.10/site-packages/google*/direct_url.json && \
+#     rm /opt/conda/lib/python3.10/site-packages/google*/REQUESTED
 # dlib has a libmkl incompatibility:
 # test_dlib_face_detector (test_dlib.TestDLib) ... INTEL MKL ERROR: /opt/conda/bin/../lib/libmkl_avx512.so.2: undefined symbol: mkl_sparse_optimize_bsr_trsm_i8.
 # Intel MKL FATAL ERROR: Cannot load libmkl_avx512.so.2 or libmkl_def.so.2.
@@ -476,9 +355,6 @@ RUN pip install wandb \
         Rtree \
         accelerate && \
         apt-get -y install libspatialindex-dev && \
-    # b/370860329: newer versions are not capable with current tensorflow
-    rm -rf /opt/conda/lib/python3.10/site-packages/numpy* && \
-    pip install "numpy==1.26.4" && \
     pip install pytorch-ignite \
         qgrid \
         bqplot \
@@ -510,9 +386,6 @@ RUN pip install wandb \
     pip install git+https://github.com/facebookresearch/segment-anything.git && \
     # b/370860329: newer versions are not capable with current tensorflow
     pip install --no-dependencies fastai fastdownload && \
-    # b/343971718: remove duplicate aiohttp installs, and reinstall it
-    rm -rf /opt/conda/lib/python3.10/site-packages/aiohttp* && \
-    micromamba install --force-reinstall -y aiohttp && \
     /tmp/clean-layer.sh
 
 # Download base easyocr models.
@@ -543,21 +416,21 @@ ENV TESSERACT_PATH=/usr/bin/tesseract \
     # For Theano with MKL
     MKL_THREADING_LAYER=GNU
 
-# Temporary fixes and patches
-# Temporary patch for Dask getting downgraded, which breaks Keras
-RUN pip install --upgrade dask && \
-    # Stop jupyter nbconvert trying to rewrite its folder hierarchy
-    mkdir -p /root/.jupyter && touch /root/.jupyter/jupyter_nbconvert_config.py && touch /root/.jupyter/migrated && \
-    mkdir -p /.jupyter && touch /.jupyter/jupyter_nbconvert_config.py && touch /.jupyter/migrated && \
-    # Stop Matplotlib printing junk to the console on first load
-    sed -i "s/^.*Matplotlib is building the font cache using fc-list.*$/# Warning removed by Kaggle/g" /opt/conda/lib/python3.10/site-packages/matplotlib/font_manager.py && \
-    # Make matplotlib output in Jupyter notebooks display correctly
-    mkdir -p /etc/ipython/ && echo "c = get_config(); c.IPKernelApp.matplotlib = 'inline'" > /etc/ipython/ipython_config.py && \
-    # Temporary patch for broken libpixman 0.38 in conda-forge, symlink to system libpixman 0.34 untile conda package gets updated to 0.38.5 or higher.
-    ln -sf /usr/lib/x86_64-linux-gnu/libpixman-1.so.0.34.0 /opt/conda/lib/libpixman-1.so.0.38.0 && \
-    # b/333854354: pin jupyter-server to version 2.12.5; later versions break LSP (b/333854354)
-    pip install --force-reinstall --no-deps jupyter_server==2.12.5 && \
-    /tmp/clean-layer.sh
+# # Temporary fixes and patches
+# # Temporary patch for Dask getting downgraded, which breaks Keras
+# RUN pip install --upgrade dask && \
+#     # Stop jupyter nbconvert trying to rewrite its folder hierarchy
+#     mkdir -p /root/.jupyter && touch /root/.jupyter/jupyter_nbconvert_config.py && touch /root/.jupyter/migrated && \
+#     mkdir -p /.jupyter && touch /.jupyter/jupyter_nbconvert_config.py && touch /.jupyter/migrated && \
+#     # Stop Matplotlib printing junk to the console on first load
+#     sed -i "s/^.*Matplotlib is building the font cache using fc-list.*$/# Warning removed by Kaggle/g" /opt/conda/lib/python3.10/site-packages/matplotlib/font_manager.py && \
+#     # Make matplotlib output in Jupyter notebooks display correctly
+#     mkdir -p /etc/ipython/ && echo "c = get_config(); c.IPKernelApp.matplotlib = 'inline'" > /etc/ipython/ipython_config.py && \
+#     # Temporary patch for broken libpixman 0.38 in conda-forge, symlink to system libpixman 0.34 untile conda package gets updated to 0.38.5 or higher.
+#     ln -sf /usr/lib/x86_64-linux-gnu/libpixman-1.so.0.34.0 /opt/conda/lib/libpixman-1.so.0.38.0 && \
+#     # b/333854354: pin jupyter-server to version 2.12.5; later versions break LSP (b/333854354)
+#     pip install --force-reinstall --no-deps jupyter_server==2.12.5 && \
+#     /tmp/clean-layer.sh
 
 # Fix to import bq_helper library without downgrading setuptools
 RUN mkdir -p ~/src && git clone https://github.com/SohierDane/BigQuery_Helper ~/src/BigQuery_Helper && \
@@ -565,44 +438,29 @@ RUN mkdir -p ~/src && git clone https://github.com/SohierDane/BigQuery_Helper ~/
     mv ~/src/BigQuery_Helper/bq_helper.py ~/src/BigQuery_Helper/bq_helper/__init__.py && \
     mv ~/src/BigQuery_Helper/test_helper.py ~/src/BigQuery_Helper/bq_helper/ && \
     sed -i 's/)/packages=["bq_helper"])/g' ~/src/BigQuery_Helper/setup.py && \
+    pip install setuptools==70.0.0 && \
     pip install -e ~/src/BigQuery_Helper && \
     /tmp/clean-layer.sh
 
-# Add BigQuery client proxy settings
-ENV PYTHONUSERBASE "/root/.local"
-ADD patches/kaggle_gcp.py \
-    patches/kaggle_secrets.py \
-    patches/kaggle_session.py \
-    patches/kaggle_web_client.py \ 
-    patches/kaggle_datasets.py \
-    patches/log.py \
-    patches/sitecustomize.py \
-    /root/.local/lib/python3.10/site-packages/
-
-# Override default imagemagick policies
-ADD patches/imagemagick-policy.xml /etc/ImageMagick-6/policy.xml
+# These patch are not working as intended:
+# # Add BigQuery client proxy settings
+# ENV PYTHONUSERBASE "/usr/local"
+# ADD patches/kaggle_gcp.py \
+#     patches/kaggle_secrets.py \
+#     patches/kaggle_session.py \
+#     patches/kaggle_web_client.py \ 
+#     patches/kaggle_datasets.py \
+#     patches/log.py \
+#     patches/sitecustomize.py \
+#     /root/.local/lib/python3.10/site-packages/
+
+# # Override default imagemagick policies
+# ADD patches/imagemagick-policy.xml /etc/ImageMagick-6/policy.xml
 
 # Add Kaggle module resolver
-ADD patches/kaggle_module_resolver.py /opt/conda/lib/python3.10/site-packages/tensorflow_hub/kaggle_module_resolver.py
-RUN sed -i '/from tensorflow_hub import uncompressed_module_resolver/a from tensorflow_hub import kaggle_module_resolver' /opt/conda/lib/python3.10/site-packages/tensorflow_hub/config.py && \
-    sed -i '/_install_default_resolvers()/a \ \ registry.resolver.add_implementation(kaggle_module_resolver.KaggleFileResolver())' /opt/conda/lib/python3.10/site-packages/tensorflow_hub/config.py && \
-    # Disable preloaded jupyter modules (they add to startup, and break when they are missing)
-    sed -i /bq_stats/d /etc/ipython/ipython_kernel_config.py && \
-    sed -i /beatrix/d /etc/ipython/ipython_kernel_config.py && \
-    sed -i /bigquery/d /etc/ipython/ipython_kernel_config.py && \
-    sed -i /sql/d /etc/ipython/ipython_kernel_config.py
-
-# Force only one libcusolver
-{{ if eq .Accelerator "gpu" }}
-RUN rm /opt/conda/bin/../lib/libcusolver.so.11 && ln -s /usr/local/cuda/lib64/libcusolver.so.11 /opt/conda/bin/../lib/libcusolver.so.11
-{{ else }}
-RUN ln -s /usr/local/cuda/lib64/libcusolver.so.11 /opt/conda/bin/../lib/libcusolver.so.11
-{{ end }}
-
-# b/270147159: conda ships with a version of libtinfo which is missing version info causing warnings, replace it with a good version.
-RUN rm /opt/conda/lib/libtinfo.so.6 && ln -s /usr/lib/x86_64-linux-gnu/libtinfo.so.6 /opt/conda/lib/libtinfo.so.6 && \
-    # b/276358430: fix Jupyter lsp freezing up the jupyter server
-    pip install "jupyter-lsp==1.5.1"
+ADD patches/kaggle_module_resolver.py /usr/local/lib/python3.10/dist-packages/tensorflow_hub/kaggle_module_resolver.py
+RUN sed -i '/from tensorflow_hub import uncompressed_module_resolver/a from tensorflow_hub import kaggle_module_resolver' /usr/local/lib/python3.10/dist-packages/tensorflow_hub/config.py && \
+    sed -i '/_install_default_resolvers()/a \ \ registry.resolver.add_implementation(kaggle_module_resolver.KaggleFileResolver())' /usr/local/lib/python3.10/dist-packages/tensorflow_hub/config.py
 
 # Set backend for matplotlib
 ENV MPLBACKEND="agg" \  
@@ -626,9 +484,3 @@ LABEL tensorflow-version=$TENSORFLOW_VERSION \
 # Correlate current release with the git hash inside the kernel editor by running `!cat /etc/git_commit`.
 RUN echo "$GIT_COMMIT" > /etc/git_commit && echo "$BUILD_DATE" > /etc/build_date
 
-{{ if eq .Accelerator "gpu" }}
-# Remove the CUDA stubs.
-ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH_NO_STUBS" \
-    # Add the CUDA home.
-    CUDA_HOME=/usr/local/cuda
-{{ end }}
diff --git a/clean-layer.sh b/clean-layer.sh
index d1a048fc..3729a1fc 100755
--- a/clean-layer.sh
+++ b/clean-layer.sh
@@ -20,5 +20,3 @@ apt-get clean
 cd /usr/local/src/
 # Delete source files used for building binaries
 rm -rf /usr/local/src/*
-# Delete conda downloaded tarballs
-conda clean -y --tarballs
diff --git a/test b/test
index ef1ffe3e..e86e3198 100755
--- a/test
+++ b/test
@@ -69,7 +69,12 @@ readonly ADDITONAL_OPTS
 readonly PATTERN
 
 set -x
-docker run --rm --net=none -v /tmp/python-build:/tmp/python-build "$IMAGE_TAG" rm -rf /tmp/python-build/*
+docker run --rm --net=none \
+    -v /tmp/python-build:/tmp/python-build \
+    --entrypoint /bin/bash \
+    "$IMAGE_TAG" \
+    -c "rm -rf /tmp/python-build/*"
+
 docker rm jupyter_test || true
 mkdir -p /tmp/python-build/tmp
 mkdir -p /tmp/python-build/devshm
@@ -79,7 +84,18 @@ mkdir -p /tmp/python-build/kaggle
 # Only run Jupyter server test if no specific test pattern is specified.
 if [ $PATTERN == 'test*.py' ]; then
     # Check that Jupyter server can run; if it dies on startup, the `docker kill` command will throw an error
-    docker run -d --name=jupyter_test --read-only --net=none -e HOME=/tmp -v $PWD:/input:ro -v /tmp/python-build/working:/working -w=/working -v /tmp/python-build/tmp:/tmp -v /tmp/python-build/devshm:/dev/shm "$IMAGE_TAG" jupyter notebook --allow-root --ip="*"
+    docker run -d --name=jupyter_test --read-only --net=none \
+        -e HOME=/tmp \
+        -v $PWD:/input:ro \
+        -v /tmp/python-build/working:/working \
+        -w /working \
+        -v /tmp/python-build/tmp:/tmp \
+        -v /tmp/python-build/devshm:/dev/shm \
+        --entrypoint /bin/bash \
+        "$IMAGE_TAG" \
+        -c "jupyter notebook" \
+        --allow-root \
+        --ip="*"
     sleep 3
     docker kill jupyter_test && docker rm jupyter_test
 fi
@@ -111,6 +127,7 @@ docker run --rm -t --read-only --net=none \
     -v /tmp/python-build/tmp:/tmp -v /tmp/python-build/devshm:/dev/shm \
     -v /tmp/python-build/kaggle:/kaggle \
     -w=/working \
+    --entrypoint /bin/bash \
     $ADDITONAL_OPTS \
     "$IMAGE_TAG" \
-    /bin/bash -c "python -m unittest discover -s /input/tests -p $PATTERN -v"
+    -c "python -m unittest discover -s /input/tests -p $PATTERN -v"
\ No newline at end of file
diff --git a/tests/common.py b/tests/common.py
index 30a7bb0f..df8cb129 100644
--- a/tests/common.py
+++ b/tests/common.py
@@ -11,7 +11,6 @@ def getAcceleratorName():
     except FileNotFoundError:
         return("nvidia-smi not found.")
 
-gpu_test = unittest.skipIf(len(os.environ.get('CUDA_VERSION', '')) == 0, 'Not running GPU tests')
+gpu_test = unittest.skipIf(len(os.environ.get('COLAB_GPU', '')) == 0, 'Not running GPU tests')
 # b/342143152 P100s are slowly being unsupported in new release of popular ml tools such as RAPIDS. 
 p100_exempt = unittest.skipIf(getAcceleratorName() == "Tesla P100-PCIE-16GB", 'Not running p100 exempt tests')
-tpu_test = unittest.skipIf(len(os.environ.get('ISTPUVM', '')) == 0, 'Not running TPU tests')
diff --git a/tests/test_jax.py b/tests/test_jax.py
index b5e0898e..f0eff66d 100644
--- a/tests/test_jax.py
+++ b/tests/test_jax.py
@@ -21,5 +21,5 @@ def test_grad(self):
         self.assertEqual(0.4199743, ag)
 
     def test_backend(self):
-        expected_backend = 'cpu' if len(os.environ.get('CUDA_VERSION', '')) == 0 else 'gpu'
+        expected_backend = 'cpu' if len(os.environ.get('COLAB_GPU', '')) == 0 else 'gpu'
         self.assertEqual(expected_backend, jax.default_backend())
diff --git a/tests/test_user_secrets.py b/tests/test_user_secrets.py
index 67c628f7..d3e949a4 100644
--- a/tests/test_user_secrets.py
+++ b/tests/test_user_secrets.py
@@ -137,6 +137,14 @@ def call_get_secret():
                           '/requests/GetUserSecretByLabelRequest', {'Label': "__gcloud_sdk_auth__"},
                           success=False)
 
+
+
+
+
+
+
+
+
     def test_set_gcloud_credentials_succeeds(self):
         secret = '{"client_id":"gcloud","type":"authorized_user","refresh_token":"refresh_token"}'
         project = 'foo'
@@ -166,6 +174,23 @@ def test_fn():
 
         self._test_client(test_fn, '/requests/GetUserSecretByLabelRequest', {'Label': "__gcloud_sdk_auth__"}, secret=secret)
 
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
     def test_set_tensorflow_credential(self):
         secret = '{"client_id":"gcloud","type":"authorized_user","refresh_token":"refresh_token"}'