#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import re
import tarfile
import traceback
import urllib.request
from shutil import rmtree

# NOTE that we shouldn't import pyspark here because this is used in
# setup.py, so assume PySpark is not importable.

DEFAULT_HADOOP = "hadoop3"
DEFAULT_HIVE = "hive2.3"
SUPPORTED_HADOOP_VERSIONS = ["hadoop3", "without-hadoop"]
SUPPORTED_HIVE_VERSIONS = ["hive2.3"]
UNSUPPORTED_COMBINATIONS = []  # type: ignore


def checked_package_name(spark_version, hadoop_version, hive_version):
"""
Check the generated package name, here we need to use the final hadoop version.
"""
return "%s-bin-%s" % (spark_version, hadoop_version)


def checked_versions(spark_version, hadoop_version, hive_version):
    """
    Check the valid combinations of supported versions in Spark distributions.

    Parameters
    ----------
    spark_version : str
        Spark version. It should be X.X.X such as '3.0.0' or 'spark-3.0.0'.
    hadoop_version : str
        Hadoop version. It should be X such as '2' or 'hadoop2'. 'without' and
        'without-hadoop' are supported as special keywords for a Hadoop-free
        distribution.
    hive_version : str
        Hive version. It should be X.X such as '2.3' or 'hive2.3'.

    Returns
    -------
    tuple
        Fully-qualified versions of Spark, Hadoop and Hive in a tuple.
        For example, 'spark-3.2.0', 'hadoop3.2' and 'hive2.3'.
    """
if re.match("^[0-9]+\\.[0-9]+\\.[0-9]+$", spark_version):
spark_version = "spark-%s" % spark_version
if not spark_version.startswith("spark-"):
raise RuntimeError(
"Spark version should start with 'spark-' prefix; however, " "got %s" % spark_version
)
if hadoop_version == "without":
hadoop_version = "without-hadoop"
elif re.match("^[0-9]+$", hadoop_version):
hadoop_version = "hadoop%s" % hadoop_version
if hadoop_version not in SUPPORTED_HADOOP_VERSIONS:
raise RuntimeError(
"Spark distribution of %s is not supported. Hadoop version should be "
"one of [%s]" % (hadoop_version, ", ".join(SUPPORTED_HADOOP_VERSIONS))
)
if re.match("^[0-9]+\\.[0-9]+$", hive_version):
hive_version = "hive%s" % hive_version
if hive_version not in SUPPORTED_HIVE_VERSIONS:
raise RuntimeError(
"Spark distribution of %s is not supported. Hive version should be "
"one of [%s]" % (hive_version, ", ".join(SUPPORTED_HADOOP_VERSIONS))
)
return spark_version, convert_old_hadoop_version(spark_version, hadoop_version), hive_version
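
# For illustration (not in the upstream file): short version strings are
# expanded, and the Hadoop label is mapped through convert_old_hadoop_version():
#
#   >>> checked_versions("3.2.0", "3", "2.3")
#   ('spark-3.2.0', 'hadoop3.2', 'hive2.3')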


def convert_old_hadoop_version(spark_version, hadoop_version):
    # For Spark <= 3.2, the package name carries the full Hadoop version:
    # convert 'hadoop3' to 'hadoop3.2' and 'hadoop2' to 'hadoop2.7'.
version_dict = {
"hadoop3": "hadoop3.2",
"hadoop2": "hadoop2.7",
"without": "without",
"without-hadoop": "without-hadoop",
}
spark_version_parts = re.search("^spark-([0-9]+)\\.([0-9]+)\\.[0-9]+$", spark_version)
spark_major_version = int(spark_version_parts.group(1))
spark_minor_version = int(spark_version_parts.group(2))
if spark_major_version < 3 or (spark_major_version == 3 and spark_minor_version <= 2):
hadoop_version = version_dict[hadoop_version]
return hadoop_version
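
# For illustration (not in the upstream file):
#
#   >>> convert_old_hadoop_version("spark-3.1.2", "hadoop3")
#   'hadoop3.2'
#   >>> convert_old_hadoop_version("spark-3.3.0", "hadoop3")
#   'hadoop3'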


def install_spark(dest, spark_version, hadoop_version, hive_version):
    """
    Installs the Spark distribution that corresponds to the given Spark, Hadoop
    and Hive versions into the given directory.

    Parameters
    ----------
    dest : str
        The location to download and install Spark.
    spark_version : str
        Spark version. It should be in the spark-X.X.X form.
    hadoop_version : str
        Hadoop version. It should be hadoopX.X
        such as 'hadoop2.7' or 'without-hadoop'.
    hive_version : str
        Hive version. It should be hiveX.X such as 'hive2.3'.
    """
package_name = checked_package_name(spark_version, hadoop_version, hive_version)
package_local_path = os.path.join(dest, "%s.tgz" % package_name)
if "PYSPARK_RELEASE_MIRROR" in os.environ:
sites = [os.environ["PYSPARK_RELEASE_MIRROR"]]
else:
sites = get_preferred_mirrors()
print("Trying to download Spark %s from [%s]" % (spark_version, ", ".join(sites)))
pretty_pkg_name = "%s for Hadoop %s" % (
spark_version,
"Free build" if hadoop_version == "without" else hadoop_version,
)
for site in sites:
os.makedirs(dest, exist_ok=True)
url = "%s/spark/%s/%s.tgz" % (site, spark_version, package_name)
tar = None
try:
print("Downloading %s from:\n- %s" % (pretty_pkg_name, url))
download_to_file(urllib.request.urlopen(url), package_local_path)
print("Installing to %s" % dest)
tar = tarfile.open(package_local_path, "r:gz")
for member in tar.getmembers():
if member.name == package_name:
# Skip the root directory.
continue
member.name = os.path.relpath(member.name, package_name + os.path.sep)
tar.extract(member, dest)
return
except Exception:
print("Failed to download %s from %s:" % (pretty_pkg_name, url))
traceback.print_exc()
rmtree(dest, ignore_errors=True)
finally:
if tar is not None:
tar.close()
if os.path.exists(package_local_path):
os.remove(package_local_path)
raise IOError("Unable to download %s." % pretty_pkg_name)
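
# Illustrative note (not part of the upstream module): setting the
# PYSPARK_RELEASE_MIRROR environment variable makes install_spark() try only
# that site instead of the preferred mirrors. The URL and paths below are
# examples:
#
#   os.environ["PYSPARK_RELEASE_MIRROR"] = "https://archive.apache.org/dist"
#   install_spark("/tmp/spark", "spark-3.2.0", "hadoop3.2", "hive2.3")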


def get_preferred_mirrors():
mirror_urls = []
for _ in range(3):
try:
response = urllib.request.urlopen(
"https://www.apache.org/dyn/closer.lua?preferred=true"
)
mirror_urls.append(response.read().decode("utf-8"))
except Exception:
# If we can't get a mirror URL, skip it. No retry.
pass
    default_sites = [
        "https://dlcdn.apache.org",
        "https://archive.apache.org/dist",
        "https://dist.apache.org/repos/dist/release",
    ]
return list(set(mirror_urls)) + [x for x in default_sites if x not in mirror_urls]


def download_to_file(response, path, chunk_size=1024 * 1024):
    total_size = int(response.info().get("Content-Length").strip())
    bytes_so_far = 0

    with open(path, mode="wb") as dest:
        while True:
            chunk = response.read(chunk_size)
            bytes_so_far += len(chunk)

            if not chunk:
                break

            dest.write(chunk)
            print(
                "Downloaded %d of %d bytes (%0.2f%%)"
                % (bytes_so_far, total_size, round(float(bytes_so_far) / total_size * 100, 2))
            )
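

# A minimal end-to-end sketch (not part of the upstream module): validate the
# requested versions, then download and extract the matching distribution.
# The destination directory below is an arbitrary example path.
if __name__ == "__main__":
    spark, hadoop, hive = checked_versions("3.2.0", DEFAULT_HADOOP, DEFAULT_HIVE)
    install_spark("/tmp/spark-install", spark, hadoop, hive)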